[Performance] Dynamic Batch Tokenizer (#9382)
This commit is contained in:
committed by
GitHub
parent
eca59f96c3
commit
94d0f656fb
@@ -49,6 +49,7 @@ from sglang.srt.hf_transformers_utils import (
|
||||
get_tokenizer_from_processor,
|
||||
)
|
||||
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
||||
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
|
||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
@@ -216,6 +217,18 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=server_args.revision,
|
||||
)
|
||||
# Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
|
||||
if (
|
||||
server_args.enable_dynamic_batch_tokenizer
|
||||
and not server_args.skip_tokenizer_init
|
||||
):
|
||||
self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
|
||||
self.tokenizer,
|
||||
max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
|
||||
batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
|
||||
)
|
||||
else:
|
||||
self.async_dynamic_batch_tokenizer = None
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.asyncio.Context(2)
|
||||
@@ -370,6 +383,144 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
):
|
||||
yield response
|
||||
|
||||
def _detect_input_format(
|
||||
self, texts: Union[str, List[str]], is_cross_encoder: bool
|
||||
) -> str:
|
||||
"""Detect the format of input texts for proper tokenization handling.
|
||||
|
||||
Returns:
|
||||
- "single_string": Regular single text like "Hello world"
|
||||
- "batch_strings": Regular batch like ["Hello", "World"]
|
||||
- "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
|
||||
"""
|
||||
if isinstance(texts, str):
|
||||
return "single_string"
|
||||
|
||||
if (
|
||||
is_cross_encoder
|
||||
and len(texts) > 0
|
||||
and isinstance(texts[0], list)
|
||||
and len(texts[0]) == 2
|
||||
):
|
||||
return "cross_encoder_pairs"
|
||||
|
||||
return "batch_strings"
|
||||
|
||||
def _prepare_tokenizer_input(
|
||||
self, texts: Union[str, List[str]], input_format: str
|
||||
) -> Union[List[str], List[List[str]]]:
|
||||
"""Prepare input for the tokenizer based on detected format."""
|
||||
if input_format == "single_string":
|
||||
return [texts] # Wrap single string for batch processing
|
||||
elif input_format == "cross_encoder_pairs":
|
||||
return texts # Already in correct format: [["query", "doc"]]
|
||||
else: # batch_strings
|
||||
return texts # Already in correct format: ["text1", "text2"]
|
||||
|
||||
def _extract_tokenizer_results(
|
||||
self,
|
||||
input_ids: List[List[int]],
|
||||
token_type_ids: Optional[List[List[int]]],
|
||||
input_format: str,
|
||||
original_batch_size: int,
|
||||
) -> Union[
|
||||
Tuple[List[int], Optional[List[int]]],
|
||||
Tuple[List[List[int]], Optional[List[List[int]]]],
|
||||
]:
|
||||
"""Extract results from tokenizer output based on input format."""
|
||||
|
||||
# For single inputs (string or single cross-encoder pair), extract first element
|
||||
if (
|
||||
input_format in ["single_string", "cross_encoder_pairs"]
|
||||
and original_batch_size == 1
|
||||
):
|
||||
single_input_ids = input_ids[0] if input_ids else []
|
||||
single_token_type_ids = token_type_ids[0] if token_type_ids else None
|
||||
return single_input_ids, single_token_type_ids
|
||||
|
||||
# For true batches, return as-is
|
||||
return input_ids, token_type_ids
|
||||
|
||||
async def _tokenize_texts(
    self, texts: Union[str, List[str]], is_cross_encoder: bool = False
) -> Union[
    Tuple[List[int], Optional[List[int]]],
    Tuple[List[List[int]], Optional[List[List[int]]]],
]:
    """
    Tokenize text(s) using the appropriate tokenizer strategy.

    This method handles multiple input formats and chooses between the async
    dynamic batch tokenizer (for single texts only) and the regular tokenizer.

    Args:
        texts: Text input in various formats:

            Regular cases:
            - Single string: "How are you?"
            - Batch of strings: ["Hello", "World", "How are you?"]

            Cross-encoder cases (sentence pairs for similarity/ranking):
            - Single pair: [["query text", "document text"]]
            - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]

        is_cross_encoder: Whether to return token_type_ids for cross-encoder
            models. Enables proper handling of sentence pairs with segment IDs.

    Returns:
        Single input cases:
            Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
            Example: ([101, 2129, 102], [0, 0, 0]) for single text
            Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair

        Batch input cases:
            Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
            Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch

        Note: token_type_ids is None unless is_cross_encoder=True.

    Raises:
        ValueError: If the tokenizer is not initialized or *texts* is empty.
    """
    # Fail fast with a message naming the actual problem, instead of one
    # error that conflates two unrelated failure modes.
    if self.tokenizer is None:
        raise ValueError(
            "tokenizer is not initialized; cannot tokenize texts "
            "(was the engine started with skip_tokenizer_init=True?)"
        )
    if not texts:
        raise ValueError("texts cannot be empty")

    # Step 1: Detect input format and prepare for tokenization
    input_format = self._detect_input_format(texts, is_cross_encoder)
    tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
    original_batch_size = 1 if isinstance(texts, str) else len(texts)

    # Step 2: Set up tokenizer arguments (segment IDs only for cross-encoders)
    tokenizer_kwargs = (
        {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
    )

    # Step 3: Choose tokenization strategy. The async dynamic batch tokenizer
    # only accepts one text at a time, so it is used solely for single strings.
    use_async_tokenizer = (
        self.async_dynamic_batch_tokenizer is not None
        and input_format == "single_string"
    )

    if use_async_tokenizer:
        logger.debug("Using async dynamic batch tokenizer for single text")
        result = await self.async_dynamic_batch_tokenizer.encode(
            tokenizer_input[0], **tokenizer_kwargs
        )
        # Convert to batch format for consistency with the regular path.
        input_ids = [result["input_ids"]]
        token_type_ids = (
            [result["token_type_ids"]]
            if is_cross_encoder and result.get("token_type_ids")
            else None
        )
    else:
        logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
        encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
        input_ids = encoded["input_ids"]
        token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None

    # Step 4: Extract results based on input format (unwrap singles)
    return self._extract_tokenizer_results(
        input_ids, token_type_ids, input_format, original_batch_size
    )
|
||||
|
||||
async def _tokenize_one_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
@@ -400,14 +551,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
"accept text prompts. Please provide input_ids or re-initialize "
|
||||
"the engine with skip_tokenizer_init=False."
|
||||
)
|
||||
encoded = self.tokenizer(
|
||||
input_text, return_token_type_ids=is_cross_encoder_request
|
||||
)
|
||||
|
||||
input_ids = encoded["input_ids"]
|
||||
if is_cross_encoder_request:
|
||||
input_ids = encoded["input_ids"][0]
|
||||
token_type_ids = encoded.get("token_type_ids", [None])[0]
|
||||
input_ids, token_type_ids = await self._tokenize_texts(
|
||||
input_text, is_cross_encoder_request
|
||||
)
|
||||
|
||||
if self.mm_processor and obj.contains_mm_input():
|
||||
if not isinstance(obj.image_data, list):
|
||||
@@ -582,17 +729,27 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
requests = [obj[i] for i in range(batch_size)]
|
||||
texts = [req.text for req in requests]
|
||||
|
||||
# Batch tokenize all texts
|
||||
encoded = self.tokenizer(texts)
|
||||
input_ids_list = encoded["input_ids"]
|
||||
# Check if any request is a cross-encoder request
|
||||
is_cross_encoder_request = any(
|
||||
isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
|
||||
for req in requests
|
||||
)
|
||||
|
||||
# Batch tokenize all texts using unified method
|
||||
input_ids_list, token_type_ids_list = await self._tokenize_texts(
|
||||
texts, is_cross_encoder_request
|
||||
)
|
||||
|
||||
# Process all requests
|
||||
tokenized_objs = []
|
||||
for i, req in enumerate(requests):
|
||||
self._validate_one_request(obj[i], input_ids_list[i])
|
||||
token_type_ids = (
|
||||
token_type_ids_list[i] if token_type_ids_list is not None else None
|
||||
)
|
||||
tokenized_objs.append(
|
||||
self._create_tokenized_object(
|
||||
req, req.text, input_ids_list[i], None, None
|
||||
req, req.text, input_ids_list[i], None, None, token_type_ids
|
||||
)
|
||||
)
|
||||
logger.debug(f"Completed batch processing for {batch_size} requests")
|
||||
|
||||
Reference in New Issue
Block a user