[Performance] Dynamic Batch Tokenizer (#9382)

Author: Sundara Raman Ramachandran
Date: 2025-09-13 10:56:04 -07:00 (committed by GitHub)
parent eca59f96c3
commit 94d0f656fb
5 changed files with 1041 additions and 11 deletions


@@ -49,6 +49,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer_from_processor,
 )
 from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -216,6 +217,18 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
+        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
+        if (
+            server_args.enable_dynamic_batch_tokenizer
+            and not server_args.skip_tokenizer_init
+        ):
+            self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
+                self.tokenizer,
+                max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
+                batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
+            )
+        else:
+            self.async_dynamic_batch_tokenizer = None

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
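
The new module sglang/srt/managers/async_dynamic_batch_tokenizer.py is among the five changed files but is not shown in this view. For orientation, here is a minimal sketch of the micro-batching pattern that the class name and the two knobs above (max_batch_size, batch_wait_timeout_s) imply: an asyncio queue plus per-request futures. The class name and all details below are illustrative assumptions, not the actual AsyncDynamicbatchTokenizer implementation.

import asyncio
from typing import Any, Dict, Optional


class DynamicBatchTokenizerSketch:
    """Illustrative stand-in for AsyncDynamicbatchTokenizer (not the real class).

    Concurrent encode() calls are queued; a single worker drains the queue,
    waiting up to batch_wait_timeout_s to gather at most max_batch_size texts,
    then runs the underlying tokenizer once for the whole micro-batch.
    """

    def __init__(self, tokenizer, max_batch_size: int, batch_wait_timeout_s: float):
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker: Optional[asyncio.Task] = None

    async def encode(self, text: str, **kwargs) -> Dict[str, Any]:
        # Lazily start the batching worker on the running loop.
        if self._worker is None:
            self._worker = asyncio.get_running_loop().create_task(self._run())
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        await self._queue.put((text, kwargs, fut))
        return await fut

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self._queue.get()]  # block until the first request arrives
            deadline = loop.time() + self.batch_wait_timeout_s
            # Collect more requests until the batch fills or the timeout expires.
            while len(batch) < self.max_batch_size:
                timeout = deadline - loop.time()
                if timeout <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), timeout))
                except asyncio.TimeoutError:
                    break
            # One tokenizer call for the whole micro-batch. A real implementation
            # would also offload this CPU-bound call to an executor and handle
            # per-request kwargs; this sketch ignores both for brevity.
            encoded = self.tokenizer([text for text, _, _ in batch])
            for i, (_, _, fut) in enumerate(batch):
                fut.set_result({"input_ids": encoded["input_ids"][i]})

Under concurrent load, e.g. results = await asyncio.gather(*(tok.encode(t) for t in texts)), the worker turns many single-text requests into a few batched tokenizer calls, which is the intended throughput win behind this PR.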
@@ -370,6 +383,144 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         ):
             yield response

+    def _detect_input_format(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool
+    ) -> str:
+        """Detect the format of input texts for proper tokenization handling.
+
+        Returns:
+            - "single_string": Regular single text like "Hello world"
+            - "batch_strings": Regular batch like ["Hello", "World"]
+            - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+        """
+        if isinstance(texts, str):
+            return "single_string"
+
+        if (
+            is_cross_encoder
+            and len(texts) > 0
+            and isinstance(texts[0], list)
+            and len(texts[0]) == 2
+        ):
+            return "cross_encoder_pairs"
+
+        return "batch_strings"
+
+    def _prepare_tokenizer_input(
+        self, texts: Union[str, List[str]], input_format: str
+    ) -> Union[List[str], List[List[str]]]:
+        """Prepare input for the tokenizer based on the detected format."""
+        if input_format == "single_string":
+            return [texts]  # Wrap single string for batch processing
+        elif input_format == "cross_encoder_pairs":
+            return texts  # Already in correct format: [["query", "doc"]]
+        else:  # batch_strings
+            return texts  # Already in correct format: ["text1", "text2"]
+
+    def _extract_tokenizer_results(
+        self,
+        input_ids: List[List[int]],
+        token_type_ids: Optional[List[List[int]]],
+        input_format: str,
+        original_batch_size: int,
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """Extract results from tokenizer output based on the input format."""
+        # For single inputs (string or single cross-encoder pair), extract the first element
+        if (
+            input_format in ["single_string", "cross_encoder_pairs"]
+            and original_batch_size == 1
+        ):
+            single_input_ids = input_ids[0] if input_ids else []
+            single_token_type_ids = token_type_ids[0] if token_type_ids else None
+            return single_input_ids, single_token_type_ids
+
+        # For true batches, return as-is
+        return input_ids, token_type_ids
+
+    async def _tokenize_texts(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool = False
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """
+        Tokenize text(s) using the appropriate tokenizer strategy.
+
+        This method handles multiple input formats and chooses between the async
+        dynamic batch tokenizer (used only for single texts) and the regular tokenizer.
+
+        Args:
+            texts: Text input in various formats:
+
+                Regular cases:
+                - Single string: "How are you?"
+                - Batch of strings: ["Hello", "World", "How are you?"]
+
+                Cross-encoder cases (sentence pairs for similarity/ranking):
+                - Single pair: [["query text", "document text"]]
+                - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+
+            is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
+                Enables proper handling of sentence pairs with segment IDs.
+
+        Returns:
+            Single input cases:
+                Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
+                Example: ([101, 2129, 102], [0, 0, 0]) for a single text
+                Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for a cross-encoder pair
+
+            Batch input cases:
+                Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
+                Example: ([[101, 2129, 102], [101, 4068, 102]], None) for a regular batch
+
+        Note: token_type_ids is None unless is_cross_encoder=True.
+        """
+        if not texts or self.tokenizer is None:
+            raise ValueError("texts cannot be empty and tokenizer must be initialized")
+
+        # Step 1: Detect input format and prepare for tokenization
+        input_format = self._detect_input_format(texts, is_cross_encoder)
+        tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
+        original_batch_size = len(texts) if not isinstance(texts, str) else 1
+
+        # Step 2: Set up tokenizer arguments
+        tokenizer_kwargs = (
+            {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
+        )
+
+        # Step 3: Choose tokenization strategy
+        use_async_tokenizer = (
+            self.async_dynamic_batch_tokenizer is not None
+            and input_format == "single_string"
+        )
+
+        if use_async_tokenizer:
+            logger.debug("Using async dynamic batch tokenizer for single text")
+            result = await self.async_dynamic_batch_tokenizer.encode(
+                tokenizer_input[0], **tokenizer_kwargs
+            )
+            # Convert to batch format for consistency
+            input_ids = [result["input_ids"]]
+            token_type_ids = (
+                [result["token_type_ids"]]
+                if is_cross_encoder and result.get("token_type_ids")
+                else None
+            )
+        else:
+            logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
+            encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
+            input_ids = encoded["input_ids"]
+            token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
+
+        # Step 4: Extract results based on input format
+        return self._extract_tokenizer_results(
+            input_ids, token_type_ids, input_format, original_batch_size
+        )
+
     async def _tokenize_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
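
As a quick illustration of the dispatch rules above, the following standalone restatement of the _detect_input_format logic can be run on its own; the function name is hypothetical and the examples mirror the docstring:

from typing import List, Union


def detect_input_format(texts: Union[str, List], is_cross_encoder: bool) -> str:
    # Mirrors TokenizerManager._detect_input_format from the hunk above.
    if isinstance(texts, str):
        return "single_string"
    if (
        is_cross_encoder
        and len(texts) > 0
        and isinstance(texts[0], list)
        and len(texts[0]) == 2
    ):
        return "cross_encoder_pairs"
    return "batch_strings"


assert detect_input_format("How are you?", False) == "single_string"
assert detect_input_format(["Hello", "World"], False) == "batch_strings"
assert detect_input_format([["query", "doc"]], True) == "cross_encoder_pairs"
# Pairs are only recognized when is_cross_encoder=True; otherwise they fall
# through to "batch_strings". Since only "single_string" inputs take the async
# dynamic batch tokenizer path, all list inputs use the regular tokenizer.
assert detect_input_format([["query", "doc"]], False) == "batch_strings"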
@@ -400,14 +551,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     "accept text prompts. Please provide input_ids or re-initialize "
                     "the engine with skip_tokenizer_init=False."
                 )
-            encoded = self.tokenizer(
-                input_text, return_token_type_ids=is_cross_encoder_request
-            )
-            input_ids = encoded["input_ids"]
-
-            if is_cross_encoder_request:
-                input_ids = encoded["input_ids"][0]
-                token_type_ids = encoded.get("token_type_ids", [None])[0]
+            input_ids, token_type_ids = await self._tokenize_texts(
+                input_text, is_cross_encoder_request
+            )

         if self.mm_processor and obj.contains_mm_input():
             if not isinstance(obj.image_data, list):
@@ -582,17 +729,27 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]

-        # Batch tokenize all texts
-        encoded = self.tokenizer(texts)
-        input_ids_list = encoded["input_ids"]
+        # Check if any request is a cross-encoder request
+        is_cross_encoder_request = any(
+            isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
+            for req in requests
+        )
+
+        # Batch tokenize all texts using the unified method
+        input_ids_list, token_type_ids_list = await self._tokenize_texts(
+            texts, is_cross_encoder_request
+        )

         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
             self._validate_one_request(obj[i], input_ids_list[i])
+            token_type_ids = (
+                token_type_ids_list[i] if token_type_ids_list is not None else None
+            )
             tokenized_objs.append(
                 self._create_tokenized_object(
-                    req, req.text, input_ids_list[i], None, None
+                    req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
         logger.debug(f"Completed batch processing for {batch_size} requests")