From 7c3f07dbcba5fb36b889ab219a758663f111e599 Mon Sep 17 00:00:00 2001 From: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com> Date: Wed, 8 Oct 2025 10:08:48 +0530 Subject: [PATCH] [Feature] Add /tokenize and /detokenize OpenAI compatible endpoints (#9545) --- docs/basic_usage/native_api.ipynb | 102 ++++++++++++- python/sglang/srt/entrypoints/http_server.py | 48 ++++++ .../sglang/srt/entrypoints/openai/protocol.py | 38 +++++ .../entrypoints/openai/serving_tokenize.py | 144 ++++++++++++++++++ test/srt/test_srt_endpoint.py | 103 +++++++++++++ 5 files changed, 434 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/entrypoints/openai/serving_tokenize.py diff --git a/docs/basic_usage/native_api.ipynb b/docs/basic_usage/native_api.ipynb index 5e4ca19a1..a62fa8d18 100644 --- a/docs/basic_usage/native_api.ipynb +++ b/docs/basic_usage/native_api.ipynb @@ -21,6 +21,8 @@ "- `/start_expert_distribution_record`\n", "- `/stop_expert_distribution_record`\n", "- `/dump_expert_distribution_record`\n", + "- `/tokenize`\n", + "- `/detokenize`\n", "- A full list of these APIs can be found at [http_server.py](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py)\n", "\n", "We mainly use `requests` to test these APIs in the following examples. You can also use `curl`.\n" @@ -477,6 +479,104 @@ "source": [ "terminate_process(expert_record_server_process)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tokenize/Detokenize Example (Round Trip)\n", + "\n", + "This example demonstrates how to use the /tokenize and /detokenize endpoints together. We first tokenize a string, then detokenize the resulting IDs to reconstruct the original text. This workflow is useful when you need to handle tokenization externally but still leverage the server for detokenization." 
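A minimal sketch of the external-tokenization workflow described above, assuming the server launched in the next cell is running and that `port` holds its port; it also assumes the native `/generate` endpoint accepts an `input_ids` field, so the IDs returned by `/tokenize` can be sent straight to generation without loading a tokenizer locally:

import requests

base_url = f"http://localhost:{port}"  # `port` comes from launch_server_cmd in the next cell

# Tokenize on the server instead of loading a tokenizer locally.
tok = requests.post(
    f"{base_url}/tokenize",
    json={"prompt": "The capital of France is", "add_special_tokens": False},
).json()

# Pass the returned IDs straight to the native /generate endpoint.
gen = requests.post(
    f"{base_url}/generate",
    json={
        "input_ids": tok["tokens"],
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
).json()
print(gen["text"])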
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer_free_server_process, port = launch_server_cmd(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(f\"http://localhost:{port}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "from sglang.utils import print_highlight\n", + "\n", + "base_url = f\"http://localhost:{port}\"\n", + "tokenize_url = f\"{base_url}/tokenize\"\n", + "detokenize_url = f\"{base_url}/detokenize\"\n", + "\n", + "model_name = \"qwen/qwen2.5-0.5b-instruct\"\n", + "input_text = \"SGLang provides efficient tokenization endpoints.\"\n", + "print_highlight(f\"Original Input Text:\\n'{input_text}'\")\n", + "\n", + "# --- tokenize the input text ---\n", + "tokenize_payload = {\n", + " \"model\": model_name,\n", + " \"prompt\": input_text,\n", + " \"add_special_tokens\": False,\n", + "}\n", + "try:\n", + " tokenize_response = requests.post(tokenize_url, json=tokenize_payload)\n", + " tokenize_response.raise_for_status()\n", + " tokenization_result = tokenize_response.json()\n", + " token_ids = tokenization_result.get(\"tokens\")\n", + "\n", + " if not token_ids:\n", + " raise ValueError(\"Tokenization returned empty tokens.\")\n", + "\n", + " print_highlight(f\"\\nTokenized Output (IDs):\\n{token_ids}\")\n", + " print_highlight(f\"Token Count: {tokenization_result.get('count')}\")\n", + " print_highlight(f\"Max Model Length: {tokenization_result.get('max_model_len')}\")\n", + "\n", + " # --- detokenize the obtained token IDs ---\n", + " detokenize_payload = {\n", + " \"model\": model_name,\n", + " \"tokens\": token_ids,\n", + " \"skip_special_tokens\": True,\n", + " }\n", + "\n", + " detokenize_response = requests.post(detokenize_url, json=detokenize_payload)\n", + " detokenize_response.raise_for_status()\n", + " detokenization_result = detokenize_response.json()\n", + " reconstructed_text = detokenization_result.get(\"text\")\n", + "\n", + " print_highlight(f\"\\nDetokenized Output (Text):\\n'{reconstructed_text}'\")\n", + "\n", + " if input_text == reconstructed_text:\n", + " print_highlight(\n", + " \"\\nRound Trip Successful: Original and reconstructed text match.\"\n", + " )\n", + " else:\n", + " print_highlight(\n", + " \"\\nRound Trip Mismatch: Original and reconstructed text differ.\"\n", + " )\n", + "\n", + "except requests.exceptions.RequestException as e:\n", + " print_highlight(f\"\\nHTTP Request Error: {e}\")\n", + "except Exception as e:\n", + " print_highlight(f\"\\nAn error occurred: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(tokenizer_free_server_process)" + ] } ], "metadata": { @@ -493,5 +593,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 9ba8e6374..c64e309c4 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -52,12 +52,14 @@ from sglang.srt.entrypoints.engine import _launch_subprocesses from sglang.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, CompletionRequest, + DetokenizeRequest, EmbeddingRequest, ErrorResponse, ModelCard, ModelList, ResponsesRequest, ScoringRequest, + TokenizeRequest, 
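The notebook cell above tokenizes a single string; the request schema added by this patch also accepts a list of prompts, in which case `tokens` and `count` come back as per-prompt lists and the nested list can be sent to `/detokenize` unchanged. A rough sketch against the same server and `base_url` as in the notebook example:

import requests

base_url = f"http://localhost:{port}"  # same server as the notebook example above

resp = requests.post(
    f"{base_url}/tokenize",
    json={
        "prompt": ["First sentence.", "Second, with 中文."],
        "add_special_tokens": False,
    },
).json()
# One token list and one count per input prompt.
for ids, n in zip(resp["tokens"], resp["count"]):
    assert len(ids) == n

# The nested list round-trips through /detokenize as a list of strings.
texts = requests.post(
    f"{base_url}/detokenize",
    json={"tokens": resp["tokens"], "skip_special_tokens": True},
).json()["text"]
print(texts)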
V1RerankReqInput, ) from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat @@ -65,6 +67,10 @@ from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompl from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from sglang.srt.entrypoints.openai.serving_rerank import OpenAIServingRerank from sglang.srt.entrypoints.openai.serving_score import OpenAIServingScore +from sglang.srt.entrypoints.openai.serving_tokenize import ( + OpenAIServingDetokenize, + OpenAIServingTokenize, +) from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( AbortReq, @@ -229,6 +235,12 @@ async def lifespan(fast_api_app: FastAPI): fast_api_app.state.openai_serving_rerank = OpenAIServingRerank( _global_state.tokenizer_manager ) + fast_api_app.state.openai_serving_tokenize = OpenAIServingTokenize( + _global_state.tokenizer_manager + ) + fast_api_app.state.openai_serving_detokenize = OpenAIServingDetokenize( + _global_state.tokenizer_manager + ) server_args: ServerArgs = fast_api_app.server_args @@ -1070,6 +1082,42 @@ async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request): ) +@app.post( + "/v1/tokenize", + response_class=ORJSONResponse, + dependencies=[Depends(validate_json_request)], +) +@app.post( + "/tokenize", + response_class=ORJSONResponse, + dependencies=[Depends(validate_json_request)], + include_in_schema=False, +) +async def openai_v1_tokenize(request: TokenizeRequest, raw_request: Request): + """OpenAI-compatible tokenization endpoint.""" + return await raw_request.app.state.openai_serving_tokenize.handle_request( + request, raw_request + ) + + +@app.post( + "/v1/detokenize", + response_class=ORJSONResponse, + dependencies=[Depends(validate_json_request)], +) +@app.post( + "/detokenize", + response_class=ORJSONResponse, + dependencies=[Depends(validate_json_request)], + include_in_schema=False, +) +async def openai_v1_detokenize(request: DetokenizeRequest, raw_request: Request): + """OpenAI-compatible detokenization endpoint.""" + return await raw_request.app.state.openai_serving_detokenize.handle_request( + request, raw_request + ) + + @app.get("/v1/models", response_class=ORJSONResponse) async def available_models(): """Show available models. OpenAI-compatible endpoint.""" diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 3acb791aa..735f6a998 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -801,12 +801,50 @@ class RerankResponse(BaseModel): meta_info: Optional[dict] = None +class TokenizeRequest(BaseModel): + """Request schema for the /tokenize endpoint.""" + + model: str = DEFAULT_MODEL_NAME + prompt: Union[str, List[str]] + add_special_tokens: bool = Field( + default=True, + description="whether to add model-specific special tokens (e.g. BOS/EOS) during encoding.", + ) + + +class TokenizeResponse(BaseModel): + """Response schema for the /tokenize endpoint.""" + + tokens: Union[List[int], List[List[int]]] + count: Union[int, List[int]] + max_model_len: int + + +class DetokenizeRequest(BaseModel): + """Request schema for the /detokenize endpoint.""" + + model: str = DEFAULT_MODEL_NAME + tokens: Union[List[int], List[List[int]]] + skip_special_tokens: bool = Field( + default=True, + description="whether to exclude special tokens (e.g. 
padding or EOS) during decoding.", + ) + + +class DetokenizeResponse(BaseModel): + """Response schema for the /detokenize endpoint.""" + + text: Union[str, List[str]] + + OpenAIServingRequest = Union[ ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ScoringRequest, V1RerankReqInput, + TokenizeRequest, + DetokenizeRequest, ] diff --git a/python/sglang/srt/entrypoints/openai/serving_tokenize.py b/python/sglang/srt/entrypoints/openai/serving_tokenize.py new file mode 100644 index 000000000..1bf6de97a --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_tokenize.py @@ -0,0 +1,144 @@ +import logging +from http import HTTPStatus +from typing import List, Union + +from fastapi import Request + +from sglang.srt.entrypoints.openai.protocol import ( + DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeRequest, + TokenizeResponse, +) +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase + +logger = logging.getLogger(__name__) + + +class OpenAIServingTokenize(OpenAIServingBase): + """Handler for /v1/tokenize requests""" + + def _request_id_prefix(self) -> str: + return "tok-" + + def _convert_to_internal_request( + self, request: TokenizeRequest, raw_request: Request + ) -> tuple[TokenizeRequest, TokenizeRequest]: + return request, request + + async def _handle_non_streaming_request( + self, + adapted_request: TokenizeRequest, + request: TokenizeRequest, + raw_request: Request, + ) -> Union[TokenizeResponse, ErrorResponse]: + try: + tokenizer = self.tokenizer_manager.tokenizer + max_model_len = getattr(tokenizer, "model_max_length", -1) + + if isinstance(request.prompt, str): + token_ids = tokenizer.encode( + request.prompt, + add_special_tokens=request.add_special_tokens, + ) + tokens = token_ids + count = len(token_ids) + elif isinstance(request.prompt, list): + token_ids_list = [ + tokenizer.encode( + text, add_special_tokens=request.add_special_tokens + ) + for text in request.prompt + ] + tokens = token_ids_list + count = [len(ids) for ids in token_ids_list] + else: + return self.create_error_response( + f"Invalid prompt type: {type(request.prompt)}. Expected str or List[str]." + ) + + return TokenizeResponse( + tokens=tokens, count=count, max_model_len=max_model_len + ) + except Exception as e: + logger.error("Error during tokenization", exc_info=True) + return self.create_error_response( + f"Internal server error during tokenization: {e}", + err_type="InternalServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + + +class OpenAIServingDetokenize(OpenAIServingBase): + """Handler for /v1/detokenize requests""" + + def _request_id_prefix(self) -> str: + return "detok-" + + def _convert_to_internal_request( + self, request: DetokenizeRequest, raw_request: Request + ) -> tuple[DetokenizeRequest, DetokenizeRequest]: + return request, request + + async def _handle_non_streaming_request( + self, + adapted_request: DetokenizeRequest, + request: DetokenizeRequest, + raw_request: Request, + ) -> Union[DetokenizeResponse, ErrorResponse]: + try: + tokenizer = self.tokenizer_manager.tokenizer + + if ( + isinstance(request.tokens, list) + and request.tokens + and isinstance(request.tokens[0], int) + ): + if not all(isinstance(t, int) for t in request.tokens): + return self.create_error_response( + "Invalid input: 'tokens' must be a list of integers." 
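The Pydantic models defined above carry the defaults the handlers rely on: `add_special_tokens` defaults to True for tokenization, `skip_special_tokens` defaults to True for detokenization, and `model` falls back to the server's default model name. A small sketch of how client payloads map onto these schemas (import path as added by this patch):

from sglang.srt.entrypoints.openai.protocol import DetokenizeRequest, TokenizeRequest

# Only "prompt" is required; the remaining fields take the defaults defined above.
req = TokenizeRequest(prompt="hello world")
print(req.add_special_tokens)  # True

# Likewise, detokenization strips special tokens unless the client opts out.
dreq = DetokenizeRequest(tokens=[[1, 2], [3]])
print(dreq.skip_special_tokens)  # True

Both the bare and `/v1`-prefixed routes registered in http_server.py deserialize into these same request models, so `/tokenize` and `/v1/tokenize` (and likewise the detokenize pair) behave identically; only the `/v1` variants appear in the OpenAPI schema.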
+ ) + tokens_to_decode = [int(t) for t in request.tokens] + text = tokenizer.decode( + tokens_to_decode, skip_special_tokens=request.skip_special_tokens + ) + text_out: Union[str, List[str]] = text + elif ( + isinstance(request.tokens, list) + and request.tokens + and isinstance(request.tokens[0], list) + ): + texts: List[str] = [] + for token_list in request.tokens: + if not all(isinstance(t, int) for t in token_list): + return self.create_error_response( + f"Invalid input: Sublist in 'tokens' must contain only integers. Found: {token_list}" + ) + decoded_text = tokenizer.decode( + [int(t) for t in token_list], + skip_special_tokens=request.skip_special_tokens, + ) + texts.append(decoded_text) + text_out = texts + elif isinstance(request.tokens, list) and not request.tokens: + text_out = "" + else: + return self.create_error_response( + f"Invalid tokens type: {type(request.tokens)}. Expected List[int] or List[List[int]]." + ) + + return DetokenizeResponse(text=text_out) + except Exception as e: + logger.error("Error during detokenization", exc_info=True) + if "decode" in str(e).lower(): + return self.create_error_response( + f"Error decoding tokens: {e}. Input tokens might be invalid for the model.", + err_type="DecodeError", + status_code=HTTPStatus.BAD_REQUEST, + ) + return self.create_error_response( + f"Internal server error during detokenization: {e}", + err_type="InternalServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR, + ) diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 089da355d..59a8c3c46 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -1,6 +1,7 @@ """ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_prefill +python3 -m unittest test_srt_endpoint.TestTokenizeDetokenize """ import json @@ -636,5 +637,107 @@ class TestSRTEndpoint(CustomTestCase): f.result() +# ------------------------------------------------------------------------- +# /tokenize & /detokenize Test Class: TestTokenizeDetokenize +# ------------------------------------------------------------------------- + + +class TestTokenizeDetokenize(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.tokenize_url = f"{cls.base_url}/tokenize" + cls.detokenize_url = f"{cls.base_url}/detokenize" + cls.session = requests.Session() + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + cls.session.close() + + def _post_json(self, url, payload): + r = self.session.post(url, json=payload) + r.raise_for_status() + return r.json() + + def test_tokenize_various_inputs(self): + single = "Hello SGLang world! 123 😊, ಪರ್ವತದ ಮೇಲೆ ಹಿಮ." 
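Both handlers above delegate to `tokenizer_manager.tokenizer`, so the invariants the tests below check (the reported `count` matching the number of returned tokens, and decoding reproducing the input) can be reasoned about with a plain Hugging Face tokenizer. A local sketch, assuming `transformers` is installed and using the Qwen model from the docs example as a stand-in for the test model:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("qwen/qwen2.5-0.5b-instruct")

text = "Hello SGLang world! 123"
# /tokenize: "tokens" is the encoded ID list, "count" is its length.
ids = tok.encode(text, add_special_tokens=False)
print(ids, len(ids))

# /detokenize: "text" is the decoded string; the round trip matches for plain text
# when no special tokens were added.
decoded = tok.decode(ids, skip_special_tokens=True)
assert decoded == text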
+ multi = ["First sentence.", "Second, with 中文."] + scenarios = [ + {"prompt": single, "add_special_tokens": True}, + {"prompt": single, "add_special_tokens": False}, + {"prompt": multi, "add_special_tokens": True}, + {"prompt": multi, "add_special_tokens": False}, + {"prompt": "", "add_special_tokens": False}, + ] + for case in scenarios: + payload = {"model": self.model, "prompt": case["prompt"]} + if "add_special_tokens" in case: + payload["add_special_tokens"] = case["add_special_tokens"] + resp = self._post_json(self.tokenize_url, payload) + tokens = resp["tokens"] + count = resp["count"] + self.assertIsInstance(tokens, list) + if not tokens: + self.assertEqual(count, 0) + else: + if isinstance(tokens[0], list): + total = sum(len(t) for t in tokens) + expected = sum(count) if isinstance(count, list) else count + else: + total = len(tokens) + expected = count + self.assertEqual(total, expected) + + def test_tokenize_invalid_type(self): + r = self.session.post( + self.tokenize_url, json={"model": self.model, "prompt": 12345} + ) + self.assertEqual(r.status_code, 400) + + def test_detokenize_roundtrip(self): + text = "Verify detokenization round trip. यह डिटोकेनाइजेशन है" + t0 = self._post_json( + self.tokenize_url, + {"model": self.model, "prompt": text, "add_special_tokens": False}, + )["tokens"] + t1 = self._post_json( + self.tokenize_url, + {"model": self.model, "prompt": text, "add_special_tokens": True}, + )["tokens"] + cases = [ + {"tokens": t0, "skip_special_tokens": True, "expected": text}, + {"tokens": t1, "skip_special_tokens": True, "expected": text}, + {"tokens": t1, "skip_special_tokens": False, "expected": None}, + {"tokens": [], "skip_special_tokens": True, "expected": ""}, + ] + for case in cases: + payload = {"model": self.model, "tokens": case["tokens"]} + if "skip_special_tokens" in case: + payload["skip_special_tokens"] = case["skip_special_tokens"] + resp = self._post_json(self.detokenize_url, payload) + text_out = resp["text"] + if case["expected"] is not None: + self.assertEqual(text_out, case["expected"]) + else: + self.assertIsInstance(text_out, str) + + def test_detokenize_invalid_tokens(self): + r = self.session.post( + self.detokenize_url, json={"model": self.model, "tokens": ["a", "b"]} + ) + self.assertEqual(r.status_code, 400) + r2 = self.session.post( + self.detokenize_url, json={"model": self.model, "tokens": [1, -1, 2]} + ) + self.assertEqual(r2.status_code, 500) + + if __name__ == "__main__": unittest.main()