[Performance] Dynamic Batch Tokenizer (#9382)

This commit is contained in:
Sundara Raman Ramachandran
2025-09-13 10:56:04 -07:00
committed by GitHub
parent eca59f96c3
commit 94d0f656fb
5 changed files with 1041 additions and 11 deletions

View File

@@ -0,0 +1,170 @@
"""
Asynchronous dynamic batch tokenizer for SGLang.
This module provides an async tokenizer with dynamic batching capabilities
to reduce tokenization overhead when multiple requests arrive concurrently.
"""
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class AsyncDynamicbatchTokenizer:
"""Asynchronous tokenizer with dynamic batching for single string prompts.
Dynamically batches pending encode requests from a queue to reduce overhead.
Only handles single string prompts - regular batch processing of multiple
strings per request should be handled at a higher level.
A single-thread ThreadPoolExecutor is used so the event loop stays responsive.
Note: Uses lazy initialization for asyncio components because this class
is instantiated in TokenizerManager.__init__() before the event loop starts.
"""
def __init__(
self,
tokenizer,
max_batch_size: int = 32,
batch_wait_timeout_s: float = 0.002,
) -> None:
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.batch_wait_timeout_s = batch_wait_timeout_s
# Single queue for all encode requests - initialized lazily
self._queue: Optional[asyncio.Queue] = None
self._batcher_task: Optional[asyncio.Task] = None
# Single-thread executor for blocking tokenizer calls
self._executor = ThreadPoolExecutor(max_workers=1)
self._initialized = False
def _ensure_initialized(self):
"""Lazy initialization of event loop dependent components."""
if not self._initialized:
self._queue = asyncio.Queue()
self._batcher_task = asyncio.create_task(self._dynamic_batch_loop())
self._initialized = True
async def __call__(self, prompt: str, **kwargs) -> Any:
"""Encode a single prompt."""
return await self.encode(prompt, **kwargs)
async def encode(self, prompt: str, **kwargs) -> Any:
"""Encode a single prompt."""
self._ensure_initialized()
result_future: asyncio.Future = asyncio.get_running_loop().create_future()
await self._queue.put((prompt, kwargs, result_future))
return await result_future
async def _dynamic_batch_loop(self):
"""Dynamically batch incoming encode requests for efficiency."""
while True:
try:
# Get the first request
prompt, kwargs, result_future = await self._queue.get()
# Collect requests into dynamic batch
prompts = [prompt]
kwargs_list = [kwargs]
result_futures = [result_future]
# Check if there are more items immediately available in the queue
# If queue is empty, process single item immediately without timeout
if self._queue.empty():
# No other requests waiting, process immediately
pass
else:
# There might be more requests, wait for dynamic batching opportunity
start_time = asyncio.get_running_loop().time()
# Collect more requests up to max_batch_size or batch_wait_timeout_s
while len(prompts) < self.max_batch_size:
elapsed = asyncio.get_running_loop().time() - start_time
if elapsed >= self.batch_wait_timeout_s:
break
remaining_time = self.batch_wait_timeout_s - elapsed
try:
prompt, kwargs, result_future = await asyncio.wait_for(
self._queue.get(), remaining_time
)
prompts.append(prompt)
kwargs_list.append(kwargs)
result_futures.append(result_future)
except asyncio.TimeoutError:
break
# Log dynamic batch information
logger.debug(
f"AsyncDynamicbatchTokenizer: Processing dynamic batch of size {len(prompts)}"
)
# Process the dynamic batch
await self._process_dynamic_batch(prompts, kwargs_list, result_futures)
except Exception as e:
logger.error(f"Error in dynamic batch loop: {e}")
# Continue the loop to handle other requests
async def _process_dynamic_batch(
self,
prompts: List[str],
kwargs_list: List[Dict],
result_futures: List[asyncio.Future],
) -> None:
"""Process a dynamic batch of encode requests for single string prompts."""
# Check if all kwargs are identical for efficient batch processing
can_batch = len(set(str(sorted(kw.items())) for kw in kwargs_list)) == 1
kwargs = kwargs_list[0] if can_batch else None
try:
# If every request uses identical kwargs we can run a single
# batch tokenizer call for a big speed-up.
if can_batch and len(prompts) > 1:
encode_fn = partial(self.tokenizer, prompts, **kwargs)
results = await asyncio.get_running_loop().run_in_executor(
self._executor, encode_fn
)
for i, fut in enumerate(result_futures):
if not fut.done():
data = {k: v[i] for k, v in results.items()}
fut.set_result(data)
else:
# Process each request individually due to different kwargs
if len(prompts) > 1 and not can_batch:
logger.warning(
f"AsyncDynamicbatchTokenizer: Dynamic batching disabled for batch of {len(prompts)} "
f"requests due to differing kwargs. This reduces performance benefits. "
f"Consider using consistent tokenization parameters across requests."
)
encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs_list)
]
results = await asyncio.get_running_loop().run_in_executor(
self._executor, encode_fn
)
for fut, res in zip(result_futures, results):
if not fut.done():
fut.set_result(res)
except Exception as e:
logger.error(f"Error in dynamic batch processing: {e}")
for fut in result_futures:
if not fut.done():
fut.set_exception(e)
def __del__(self):
"""Clean up background tasks."""
if hasattr(self, "_batcher_task") and self._batcher_task:
if not self._batcher_task.done():
self._batcher_task.cancel()
if hasattr(self, "_executor"):
self._executor.shutdown(wait=False)

View File

@@ -49,6 +49,7 @@ from sglang.srt.hf_transformers_utils import (
get_tokenizer_from_processor,
)
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
@@ -216,6 +217,18 @@ class TokenizerManager(TokenizerCommunicatorMixin):
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
# Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
if (
server_args.enable_dynamic_batch_tokenizer
and not server_args.skip_tokenizer_init
):
self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
self.tokenizer,
max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
)
else:
self.async_dynamic_batch_tokenizer = None
# Init inter-process communication
context = zmq.asyncio.Context(2)
@@ -370,6 +383,144 @@ class TokenizerManager(TokenizerCommunicatorMixin):
):
yield response
def _detect_input_format(
self, texts: Union[str, List[str]], is_cross_encoder: bool
) -> str:
"""Detect the format of input texts for proper tokenization handling.
Returns:
- "single_string": Regular single text like "Hello world"
- "batch_strings": Regular batch like ["Hello", "World"]
- "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
"""
if isinstance(texts, str):
return "single_string"
if (
is_cross_encoder
and len(texts) > 0
and isinstance(texts[0], list)
and len(texts[0]) == 2
):
return "cross_encoder_pairs"
return "batch_strings"
def _prepare_tokenizer_input(
self, texts: Union[str, List[str]], input_format: str
) -> Union[List[str], List[List[str]]]:
"""Prepare input for the tokenizer based on detected format."""
if input_format == "single_string":
return [texts] # Wrap single string for batch processing
elif input_format == "cross_encoder_pairs":
return texts # Already in correct format: [["query", "doc"]]
else: # batch_strings
return texts # Already in correct format: ["text1", "text2"]
def _extract_tokenizer_results(
self,
input_ids: List[List[int]],
token_type_ids: Optional[List[List[int]]],
input_format: str,
original_batch_size: int,
) -> Union[
Tuple[List[int], Optional[List[int]]],
Tuple[List[List[int]], Optional[List[List[int]]]],
]:
"""Extract results from tokenizer output based on input format."""
# For single inputs (string or single cross-encoder pair), extract first element
if (
input_format in ["single_string", "cross_encoder_pairs"]
and original_batch_size == 1
):
single_input_ids = input_ids[0] if input_ids else []
single_token_type_ids = token_type_ids[0] if token_type_ids else None
return single_input_ids, single_token_type_ids
# For true batches, return as-is
return input_ids, token_type_ids
async def _tokenize_texts(
self, texts: Union[str, List[str]], is_cross_encoder: bool = False
) -> Union[
Tuple[List[int], Optional[List[int]]],
Tuple[List[List[int]], Optional[List[List[int]]]],
]:
"""
Tokenize text(s) using the appropriate tokenizer strategy.
This method handles multiple input formats and chooses between async dynamic
batch tokenizer (for single texts only) and regular tokenizer.
Args:
texts: Text input in various formats:
Regular cases:
- Single string: "How are you?"
- Batch of strings: ["Hello", "World", "How are you?"]
Cross-encoder cases (sentence pairs for similarity/ranking):
- Single pair: [["query text", "document text"]]
- Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
Enables proper handling of sentence pairs with segment IDs.
Returns:
Single input cases:
Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
Example: ([101, 2129, 102], [0, 0, 0]) for single text
Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
Batch input cases:
Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
Note: token_type_ids is None unless is_cross_encoder=True.
"""
if not texts or self.tokenizer is None:
raise ValueError("texts cannot be empty and tokenizer must be initialized")
# Step 1: Detect input format and prepare for tokenization
input_format = self._detect_input_format(texts, is_cross_encoder)
tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
original_batch_size = len(texts) if not isinstance(texts, str) else 1
# Step 2: Set up tokenizer arguments
tokenizer_kwargs = (
{"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
)
# Step 3: Choose tokenization strategy
use_async_tokenizer = (
self.async_dynamic_batch_tokenizer is not None
and input_format == "single_string"
)
if use_async_tokenizer:
logger.debug("Using async dynamic batch tokenizer for single text")
result = await self.async_dynamic_batch_tokenizer.encode(
tokenizer_input[0], **tokenizer_kwargs
)
# Convert to batch format for consistency
input_ids = [result["input_ids"]]
token_type_ids = (
[result["token_type_ids"]]
if is_cross_encoder and result.get("token_type_ids")
else None
)
else:
logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
input_ids = encoded["input_ids"]
token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
# Step 4: Extract results based on input format
return self._extract_tokenizer_results(
input_ids, token_type_ids, input_format, original_batch_size
)
async def _tokenize_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -400,14 +551,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"accept text prompts. Please provide input_ids or re-initialize "
"the engine with skip_tokenizer_init=False."
)
encoded = self.tokenizer(
input_text, return_token_type_ids=is_cross_encoder_request
)
input_ids = encoded["input_ids"]
if is_cross_encoder_request:
input_ids = encoded["input_ids"][0]
token_type_ids = encoded.get("token_type_ids", [None])[0]
input_ids, token_type_ids = await self._tokenize_texts(
input_text, is_cross_encoder_request
)
if self.mm_processor and obj.contains_mm_input():
if not isinstance(obj.image_data, list):
@@ -582,17 +729,27 @@ class TokenizerManager(TokenizerCommunicatorMixin):
requests = [obj[i] for i in range(batch_size)]
texts = [req.text for req in requests]
# Batch tokenize all texts
encoded = self.tokenizer(texts)
input_ids_list = encoded["input_ids"]
# Check if any request is a cross-encoder request
is_cross_encoder_request = any(
isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
for req in requests
)
# Batch tokenize all texts using unified method
input_ids_list, token_type_ids_list = await self._tokenize_texts(
texts, is_cross_encoder_request
)
# Process all requests
tokenized_objs = []
for i, req in enumerate(requests):
self._validate_one_request(obj[i], input_ids_list[i])
token_type_ids = (
token_type_ids_list[i] if token_type_ids_list is not None else None
)
tokenized_objs.append(
self._create_tokenized_object(
req, req.text, input_ids_list[i], None, None
req, req.text, input_ids_list[i], None, None, token_type_ids
)
)
logger.debug(f"Completed batch processing for {batch_size} requests")

View File

@@ -373,6 +373,11 @@ class ServerArgs:
scheduler_recv_interval: int = 1
numa_node: Optional[List[int]] = None
# Dynamic batch tokenizer
enable_dynamic_batch_tokenizer: bool = False
dynamic_batch_tokenizer_batch_size: int = 32
dynamic_batch_tokenizer_batch_timeout: float = 0.002
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
debug_tensor_dump_input_file: Optional[str] = None
@@ -874,6 +879,13 @@ class ServerArgs:
self.disable_cuda_graph = True
logger.warning("Cuda graph is disabled for prefill server")
# Validation: prevent both tokenizer batching features from being enabled
if self.enable_tokenizer_batch_encode and self.enable_dynamic_batch_tokenizer:
raise ValueError(
"Cannot enable both --enable-tokenizer-batch-encode and --enable-dynamic-batch-tokenizer. "
"Please choose one tokenizer batching approach."
)
# Propagate env vars
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
"1" if self.enable_torch_compile else "0"
@@ -2162,6 +2174,23 @@ class ServerArgs:
action="store_true",
help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
)
parser.add_argument(
"--enable-dynamic-batch-tokenizer",
action="store_true",
help="Enable async dynamic batch tokenizer for improved performance when multiple requests arrive concurrently.",
)
parser.add_argument(
"--dynamic-batch-tokenizer-batch-size",
type=int,
default=ServerArgs.dynamic_batch_tokenizer_batch_size,
help="[Only used if --enable-dynamic-batch-tokenizer is set] Maximum batch size for dynamic batch tokenizer.",
)
parser.add_argument(
"--dynamic-batch-tokenizer-batch-timeout",
type=float,
default=ServerArgs.dynamic_batch_tokenizer_batch_timeout,
help="[Only used if --enable-dynamic-batch-tokenizer is set] Timeout in seconds for batching tokenization requests.",
)
# PD disaggregation
parser.add_argument(