[Feature] Add /tokenize and /detokenize OpenAI compatible endpoints (#9545)
This commit is contained in:
committed by
GitHub
parent
edd86b8853
commit
7c3f07dbcb
@@ -21,6 +21,8 @@
|
|||||||
"- `/start_expert_distribution_record`\n",
|
"- `/start_expert_distribution_record`\n",
|
||||||
"- `/stop_expert_distribution_record`\n",
|
"- `/stop_expert_distribution_record`\n",
|
||||||
"- `/dump_expert_distribution_record`\n",
|
"- `/dump_expert_distribution_record`\n",
|
||||||
|
"- `/tokenize`\n",
|
||||||
|
"- `/detokenize`\n",
|
||||||
"- A full list of these APIs can be found at [http_server.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py)\n",
|
"- A full list of these APIs can be found at [http_server.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`.\n"
|
"We mainly use `requests` to test these APIs in the following examples. You can also use `curl`.\n"
|
||||||
@@ -477,6 +479,104 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"terminate_process(expert_record_server_process)"
|
"terminate_process(expert_record_server_process)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Tokenize/Detokenize Example (Round Trip)\n",
|
||||||
|
"\n",
|
||||||
|
"This example demonstrates how to use the /tokenize and /detokenize endpoints together. We first tokenize a string, then detokenize the resulting IDs to reconstruct the original text. This workflow is useful when you need to handle tokenization externally but still leverage the server for detokenization."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tokenizer_free_server_process, port = launch_server_cmd(\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"wait_for_server(f\"http://localhost:{port}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import requests\n",
|
||||||
|
"from sglang.utils import print_highlight\n",
|
||||||
|
"\n",
|
||||||
|
"base_url = f\"http://localhost:{port}\"\n",
|
||||||
|
"tokenize_url = f\"{base_url}/tokenize\"\n",
|
||||||
|
"detokenize_url = f\"{base_url}/detokenize\"\n",
|
||||||
|
"\n",
|
||||||
|
"model_name = \"qwen/qwen2.5-0.5b-instruct\"\n",
|
||||||
|
"input_text = \"SGLang provides efficient tokenization endpoints.\"\n",
|
||||||
|
"print_highlight(f\"Original Input Text:\\n'{input_text}'\")\n",
|
||||||
|
"\n",
|
||||||
|
"# --- tokenize the input text ---\n",
|
||||||
|
"tokenize_payload = {\n",
|
||||||
|
" \"model\": model_name,\n",
|
||||||
|
" \"prompt\": input_text,\n",
|
||||||
|
" \"add_special_tokens\": False,\n",
|
||||||
|
"}\n",
|
||||||
|
"try:\n",
|
||||||
|
" tokenize_response = requests.post(tokenize_url, json=tokenize_payload)\n",
|
||||||
|
" tokenize_response.raise_for_status()\n",
|
||||||
|
" tokenization_result = tokenize_response.json()\n",
|
||||||
|
" token_ids = tokenization_result.get(\"tokens\")\n",
|
||||||
|
"\n",
|
||||||
|
" if not token_ids:\n",
|
||||||
|
" raise ValueError(\"Tokenization returned empty tokens.\")\n",
|
||||||
|
"\n",
|
||||||
|
" print_highlight(f\"\\nTokenized Output (IDs):\\n{token_ids}\")\n",
|
||||||
|
" print_highlight(f\"Token Count: {tokenization_result.get('count')}\")\n",
|
||||||
|
" print_highlight(f\"Max Model Length: {tokenization_result.get('max_model_len')}\")\n",
|
||||||
|
"\n",
|
||||||
|
" # --- detokenize the obtained token IDs ---\n",
|
||||||
|
" detokenize_payload = {\n",
|
||||||
|
" \"model\": model_name,\n",
|
||||||
|
" \"tokens\": token_ids,\n",
|
||||||
|
" \"skip_special_tokens\": True,\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" detokenize_response = requests.post(detokenize_url, json=detokenize_payload)\n",
|
||||||
|
" detokenize_response.raise_for_status()\n",
|
||||||
|
" detokenization_result = detokenize_response.json()\n",
|
||||||
|
" reconstructed_text = detokenization_result.get(\"text\")\n",
|
||||||
|
"\n",
|
||||||
|
" print_highlight(f\"\\nDetokenized Output (Text):\\n'{reconstructed_text}'\")\n",
|
||||||
|
"\n",
|
||||||
|
" if input_text == reconstructed_text:\n",
|
||||||
|
" print_highlight(\n",
|
||||||
|
" \"\\nRound Trip Successful: Original and reconstructed text match.\"\n",
|
||||||
|
" )\n",
|
||||||
|
" else:\n",
|
||||||
|
" print_highlight(\n",
|
||||||
|
" \"\\nRound Trip Mismatch: Original and reconstructed text differ.\"\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"except requests.exceptions.RequestException as e:\n",
|
||||||
|
" print_highlight(f\"\\nHTTP Request Error: {e}\")\n",
|
||||||
|
"except Exception as e:\n",
|
||||||
|
" print_highlight(f\"\\nAn error occurred: {e}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"terminate_process(tokenizer_free_server_process)"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@@ -493,5 +593,5 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 2
|
"nbformat_minor": 4
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,12 +52,14 @@ from sglang.srt.entrypoints.engine import _launch_subprocesses
|
|||||||
from sglang.srt.entrypoints.openai.protocol import (
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
|
DetokenizeRequest,
|
||||||
EmbeddingRequest,
|
EmbeddingRequest,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
ModelCard,
|
ModelCard,
|
||||||
ModelList,
|
ModelList,
|
||||||
ResponsesRequest,
|
ResponsesRequest,
|
||||||
ScoringRequest,
|
ScoringRequest,
|
||||||
|
TokenizeRequest,
|
||||||
V1RerankReqInput,
|
V1RerankReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
@@ -65,6 +67,10 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl
|
|||||||
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||||
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
|
from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank
|
||||||
from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
|
from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore
|
||||||
|
from sglang.srt.entrypoints.openai.serving_tokenize import (
|
||||||
|
OpenAIServingDetokenize,
|
||||||
|
OpenAIServingTokenize,
|
||||||
|
)
|
||||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
@@ -229,6 +235,12 @@ async def lifespan(fast_api_app: FastAPI):
|
|||||||
fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
|
fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(
|
||||||
_global_state.tokenizer_manager
|
_global_state.tokenizer_manager
|
||||||
)
|
)
|
||||||
|
fast_api_app.state.openai_serving_tokenize = OpenAIServingTokenize(
|
||||||
|
_global_state.tokenizer_manager
|
||||||
|
)
|
||||||
|
fast_api_app.state.openai_serving_detokenize = OpenAIServingDetokenize(
|
||||||
|
_global_state.tokenizer_manager
|
||||||
|
)
|
||||||
|
|
||||||
server_args: ServerArgs = fast_api_app.server_args
|
server_args: ServerArgs = fast_api_app.server_args
|
||||||
|
|
||||||
@@ -1070,6 +1082,42 @@ async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Registered twice: the versioned path is part of the published OpenAPI schema,
# while the bare "/tokenize" alias is served but hidden (include_in_schema=False).
@app.post(
    "/v1/tokenize",
    response_class=ORJSONResponse,
    dependencies=[Depends(validate_json_request)],
)
@app.post(
    "/tokenize",
    response_class=ORJSONResponse,
    dependencies=[Depends(validate_json_request)],
    include_in_schema=False,
)
async def openai_v1_tokenize(request: TokenizeRequest, raw_request: Request):
    """OpenAI-compatible tokenization endpoint."""
    # The handler is the OpenAIServingTokenize instance stored on the app state
    # during lifespan start-up; the route itself only forwards the request.
    tokenize_handler = raw_request.app.state.openai_serving_tokenize
    return await tokenize_handler.handle_request(request, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
# Registered twice: the versioned path is part of the published OpenAPI schema,
# while the bare "/detokenize" alias is served but hidden (include_in_schema=False).
@app.post(
    "/v1/detokenize",
    response_class=ORJSONResponse,
    dependencies=[Depends(validate_json_request)],
)
@app.post(
    "/detokenize",
    response_class=ORJSONResponse,
    dependencies=[Depends(validate_json_request)],
    include_in_schema=False,
)
async def openai_v1_detokenize(request: DetokenizeRequest, raw_request: Request):
    """OpenAI-compatible detokenization endpoint."""
    # The handler is the OpenAIServingDetokenize instance stored on the app state
    # during lifespan start-up; the route itself only forwards the request.
    detokenize_handler = raw_request.app.state.openai_serving_detokenize
    return await detokenize_handler.handle_request(request, raw_request)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models", response_class=ORJSONResponse)
|
@app.get("/v1/models", response_class=ORJSONResponse)
|
||||||
async def available_models():
|
async def available_models():
|
||||||
"""Show available models. OpenAI-compatible endpoint."""
|
"""Show available models. OpenAI-compatible endpoint."""
|
||||||
|
|||||||
@@ -801,12 +801,50 @@ class RerankResponse(BaseModel):
|
|||||||
meta_info: Optional[dict] = None
|
meta_info: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizeRequest(BaseModel):
    """Request schema for the /tokenize endpoint."""

    # Model identifier supplied by the client; DEFAULT_MODEL_NAME when omitted.
    model: str = DEFAULT_MODEL_NAME
    # Text to encode: one string, or a list of strings for a batch request.
    # The response mirrors this shape (flat token list vs. list of token lists).
    prompt: Union[str, List[str]]
    # Passed through to the tokenizer's encode call as `add_special_tokens`.
    add_special_tokens: bool = Field(
        default=True,
        description="whether to add model-specific special tokens (e.g. BOS/EOS) during encoding.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizeResponse(BaseModel):
    """Response schema for the /tokenize endpoint."""

    # Token IDs: a flat list for a single prompt, one list per prompt for a batch.
    tokens: Union[List[int], List[List[int]]]
    # Number of tokens: a single int for a single prompt, per-prompt ints for a batch.
    count: Union[int, List[int]]
    # Maximum length reported by the tokenizer (its `model_max_length` attribute);
    # the tokenize handler uses -1 when the tokenizer does not expose one.
    max_model_len: int
|
||||||
|
|
||||||
|
|
||||||
|
class DetokenizeRequest(BaseModel):
    """Request schema for the /detokenize endpoint."""

    # Model identifier supplied by the client; DEFAULT_MODEL_NAME when omitted.
    model: str = DEFAULT_MODEL_NAME
    # Token IDs to decode: a flat list for one sequence, or a list of lists for
    # a batch. An empty list is accepted and decodes to an empty string.
    tokens: Union[List[int], List[List[int]]]
    # Passed through to the tokenizer's decode call as `skip_special_tokens`.
    skip_special_tokens: bool = Field(
        default=True,
        description="whether to exclude special tokens (e.g. padding or EOS) during decoding.",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class DetokenizeResponse(BaseModel):
    """Response schema for the /detokenize endpoint."""

    # Decoded text: one string for a flat token list, one string per sequence
    # for a batch (list of token lists).
    text: Union[str, List[str]]
|
||||||
|
|
||||||
|
|
||||||
OpenAIServingRequest = Union[
|
OpenAIServingRequest = Union[
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
CompletionRequest,
|
CompletionRequest,
|
||||||
EmbeddingRequest,
|
EmbeddingRequest,
|
||||||
ScoringRequest,
|
ScoringRequest,
|
||||||
V1RerankReqInput,
|
V1RerankReqInput,
|
||||||
|
TokenizeRequest,
|
||||||
|
DetokenizeRequest,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
144
python/sglang/srt/entrypoints/openai/serving_tokenize.py
Normal file
144
python/sglang/srt/entrypoints/openai/serving_tokenize.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import logging
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.openai.protocol import (
|
||||||
|
DetokenizeRequest,
|
||||||
|
DetokenizeResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
TokenizeRequest,
|
||||||
|
TokenizeResponse,
|
||||||
|
)
|
||||||
|
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIServingTokenize(OpenAIServingBase):
    """Handler for /v1/tokenize requests.

    Encodes the request prompt(s) with the tokenizer held by
    ``self.tokenizer_manager`` and returns the token IDs, their count(s) and
    the tokenizer's advertised maximum length.
    """

    # Largest ``max_model_len`` that is reported as-is (signed 64-bit maximum).
    # The route serializes with ORJSONResponse, which cannot emit arbitrarily
    # large integers, and Hugging Face tokenizers use a huge sentinel
    # (int(1e30)) when no limit is configured.
    _MAX_REPORTABLE_MODEL_LEN = 2**63 - 1

    def _request_id_prefix(self) -> str:
        """Return the prefix used for request IDs of tokenize calls."""
        return "tok-"

    def _convert_to_internal_request(
        self, request: TokenizeRequest, raw_request: Request
    ) -> tuple[TokenizeRequest, TokenizeRequest]:
        """Return the request unchanged as both the adapted and original request.

        Tokenization is served directly from the tokenizer, so no separate
        internal request object is built.
        """
        return request, request

    @classmethod
    def _resolve_max_model_len(cls, tokenizer) -> int:
        """Return a JSON-safe ``max_model_len`` for ``tokenizer``.

        Reads ``tokenizer.model_max_length`` and falls back to ``-1`` (the
        existing "unknown" sentinel) when the attribute is missing, is not an
        integer value, or exceeds the signed 64-bit range.
        """
        limit = getattr(tokenizer, "model_max_length", -1)
        # Accept integral floats (e.g. 4096.0) so a usable limit is not lost.
        if isinstance(limit, float) and limit.is_integer():
            limit = int(limit)
        if isinstance(limit, int) and limit <= cls._MAX_REPORTABLE_MODEL_LEN:
            return limit
        return -1

    async def _handle_non_streaming_request(
        self,
        adapted_request: TokenizeRequest,
        request: TokenizeRequest,
        raw_request: Request,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        """Tokenize ``request.prompt`` and build the response.

        Args:
            adapted_request: Same object as ``request`` (see
                ``_convert_to_internal_request``); unused here.
            request: Validated tokenize request.
            raw_request: Underlying FastAPI request; unused here.

        Returns:
            A ``TokenizeResponse`` whose ``tokens``/``count`` mirror the prompt
            shape (single string -> flat list and int, list of strings -> list
            of lists and list of ints), or an error response when the prompt
            has an unsupported type or the tokenizer raises.
        """
        try:
            tokenizer = self.tokenizer_manager.tokenizer
            max_model_len = self._resolve_max_model_len(tokenizer)

            prompt = request.prompt
            if isinstance(prompt, str):
                tokens = tokenizer.encode(
                    prompt,
                    add_special_tokens=request.add_special_tokens,
                )
                count = len(tokens)
            elif isinstance(prompt, list):
                tokens = [
                    tokenizer.encode(
                        text, add_special_tokens=request.add_special_tokens
                    )
                    for text in prompt
                ]
                count = [len(ids) for ids in tokens]
            else:
                # Defensive: schema validation should already reject other types.
                return self.create_error_response(
                    f"Invalid prompt type: {type(request.prompt)}. Expected str or List[str]."
                )

            return TokenizeResponse(
                tokens=tokens, count=count, max_model_len=max_model_len
            )
        except Exception as e:
            # Boundary handler: log the traceback and surface a 500 to the client.
            logger.error("Error during tokenization", exc_info=True)
            return self.create_error_response(
                f"Internal server error during tokenization: {e}",
                err_type="InternalServerError",
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            )
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIServingDetokenize(OpenAIServingBase):
    """Handler for /v1/detokenize requests.

    Decodes token IDs with the tokenizer held by ``self.tokenizer_manager``.
    A flat list of IDs yields one string; a list of ID lists yields one string
    per sub-list; an empty list yields an empty string.
    """

    def _request_id_prefix(self) -> str:
        """Return the prefix used for request IDs of detokenize calls."""
        return "detok-"

    def _convert_to_internal_request(
        self, request: DetokenizeRequest, raw_request: Request
    ) -> tuple[DetokenizeRequest, DetokenizeRequest]:
        """Return the request unchanged as both the adapted and original request."""
        return request, request

    def _detokenize_payload(
        self, request: DetokenizeRequest
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Validate ``request.tokens`` and decode them.

        Exceptions raised by the tokenizer are deliberately not caught here;
        the caller maps them to error responses.
        """
        tokenizer = self.tokenizer_manager.tokenizer
        payload = request.tokens

        if not isinstance(payload, list):
            return self.create_error_response(
                f"Invalid tokens type: {type(request.tokens)}. Expected List[int] or List[List[int]]."
            )

        if not payload:
            return DetokenizeResponse(text="")

        # The first element decides whether this is a single sequence or a batch.
        head = payload[0]

        if isinstance(head, int):
            if any(not isinstance(tok, int) for tok in payload):
                return self.create_error_response(
                    "Invalid input: 'tokens' must be a list of integers."
                )
            decoded = tokenizer.decode(
                [int(tok) for tok in payload],
                skip_special_tokens=request.skip_special_tokens,
            )
            return DetokenizeResponse(text=decoded)

        if isinstance(head, list):
            decoded_batch: List[str] = []
            # Validate and decode sequence by sequence so the first problem
            # encountered (bad sub-list or tokenizer failure) is the one reported.
            for sequence in payload:
                if any(not isinstance(tok, int) for tok in sequence):
                    return self.create_error_response(
                        f"Invalid input: Sublist in 'tokens' must contain only integers. Found: {sequence}"
                    )
                decoded_batch.append(
                    tokenizer.decode(
                        [int(tok) for tok in sequence],
                        skip_special_tokens=request.skip_special_tokens,
                    )
                )
            return DetokenizeResponse(text=decoded_batch)

        return self.create_error_response(
            f"Invalid tokens type: {type(request.tokens)}. Expected List[int] or List[List[int]]."
        )

    async def _handle_non_streaming_request(
        self,
        adapted_request: DetokenizeRequest,
        request: DetokenizeRequest,
        raw_request: Request,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        """Decode ``request.tokens`` and build the response.

        Args:
            adapted_request: Same object as ``request``; unused here.
            request: Validated detokenize request.
            raw_request: Underlying FastAPI request; unused here.

        Returns:
            A ``DetokenizeResponse`` on success. On failure, an error response:
            a ``DecodeError`` (400) when the exception message mentions
            "decode", otherwise an ``InternalServerError`` (500).
        """
        try:
            return self._detokenize_payload(request)
        except Exception as e:
            logger.exception("Error during detokenization")
            # Classification is by message text: anything mentioning "decode"
            # is treated as a client-side problem with the supplied tokens.
            mentions_decode = "decode" in str(e).lower()
            if mentions_decode:
                return self.create_error_response(
                    f"Error decoding tokens: {e}. Input tokens might be invalid for the model.",
                    err_type="DecodeError",
                    status_code=HTTPStatus.BAD_REQUEST,
                )
            return self.create_error_response(
                f"Internal server error during detokenization: {e}",
                err_type="InternalServerError",
                status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            )
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
|
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
|
||||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_prefill
|
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_prefill
|
||||||
|
python3 -m unittest test_srt_endpoint.TestTokenizeDetokenize
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@@ -636,5 +637,107 @@ class TestSRTEndpoint(CustomTestCase):
|
|||||||
f.result()
|
f.result()
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# /tokenize & /detokenize Test Class: TestTokenizeDetokenize
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenizeDetokenize(CustomTestCase):
    """End-to-end tests for /tokenize and /detokenize against a launched server."""

    @classmethod
    def setUpClass(cls):
        """Launch one server for the whole class and open a shared HTTP session."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.tokenize_url = f"{cls.base_url}/tokenize"
        cls.detokenize_url = f"{cls.base_url}/detokenize"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )
        # Opened only after a successful launch: tearDownClass does not run
        # when setUpClass raises, so an earlier session would never be closed.
        cls.session = requests.Session()

    @classmethod
    def tearDownClass(cls):
        """Stop the server process tree and close the shared session."""
        kill_process_tree(cls.process.pid)
        cls.session.close()

    def _post_json(self, url, payload):
        """POST ``payload`` as JSON, fail on HTTP errors, return the decoded body."""
        r = self.session.post(url, json=payload)
        r.raise_for_status()
        return r.json()

    def test_tokenize_various_inputs(self):
        """Token counts must match the returned tokens for single and batch prompts."""
        single = "Hello SGLang world! 123 😊, ಪರ್ವತದ ಮೇಲೆ ಹಿಮ."
        multi = ["First sentence.", "Second, with 中文."]
        # (prompt, add_special_tokens)
        scenarios = [
            (single, True),
            (single, False),
            (multi, True),
            (multi, False),
            ("", False),
        ]
        for prompt, add_special_tokens in scenarios:
            # subTest keeps one failing scenario from hiding the others.
            with self.subTest(prompt=prompt, add_special_tokens=add_special_tokens):
                resp = self._post_json(
                    self.tokenize_url,
                    {
                        "model": self.model,
                        "prompt": prompt,
                        "add_special_tokens": add_special_tokens,
                    },
                )
                tokens = resp["tokens"]
                count = resp["count"]
                self.assertIsInstance(tokens, list)
                if isinstance(prompt, list):
                    # Batch: one token list and one count per input string.
                    self.assertEqual(len(tokens), len(prompt))
                    self.assertEqual(count, [len(ids) for ids in tokens])
                else:
                    self.assertEqual(count, len(tokens))

    def test_tokenize_invalid_type(self):
        """A non-string, non-list prompt is rejected with HTTP 400."""
        r = self.session.post(
            self.tokenize_url, json={"model": self.model, "prompt": 12345}
        )
        self.assertEqual(r.status_code, 400)

    def test_detokenize_roundtrip(self):
        """Tokenize then detokenize must reproduce the original text."""
        text = "Verify detokenization round trip. यह डिटोकेनाइजेशन है"
        t0 = self._post_json(
            self.tokenize_url,
            {"model": self.model, "prompt": text, "add_special_tokens": False},
        )["tokens"]
        t1 = self._post_json(
            self.tokenize_url,
            {"model": self.model, "prompt": text, "add_special_tokens": True},
        )["tokens"]
        # (tokens, skip_special_tokens, expected text); expected=None means the
        # exact text is not pinned, only that a string comes back.
        cases = [
            (t0, True, text),
            (t1, True, text),
            (t1, False, None),
            ([], True, ""),
        ]
        for tokens, skip_special_tokens, expected in cases:
            with self.subTest(
                tokens=tokens, skip_special_tokens=skip_special_tokens
            ):
                resp = self._post_json(
                    self.detokenize_url,
                    {
                        "model": self.model,
                        "tokens": tokens,
                        "skip_special_tokens": skip_special_tokens,
                    },
                )
                text_out = resp["text"]
                if expected is not None:
                    self.assertEqual(text_out, expected)
                else:
                    self.assertIsInstance(text_out, str)

    def test_detokenize_invalid_tokens(self):
        """Non-integer tokens yield 400; an out-of-range ID yields 500."""
        r = self.session.post(
            self.detokenize_url, json={"model": self.model, "tokens": ["a", "b"]}
        )
        self.assertEqual(r.status_code, 400)
        r2 = self.session.post(
            self.detokenize_url, json={"model": self.model, "tokens": [1, -1, 2]}
        )
        self.assertEqual(r2.status_code, 500)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user