Perform Batch Tokenization. (#5141)

This commit is contained in:
Sundara Raman Ramachandran
2025-04-20 18:10:37 -07:00
committed by GitHub
parent 2b3bdc938e
commit f08154193c
4 changed files with 434 additions and 30 deletions

View File

@@ -415,6 +415,51 @@ class TokenizerManager:
)
if image_inputs and "input_ids" in image_inputs:
input_ids = image_inputs["input_ids"]
self._validate_token_len(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, image_inputs
)
def _validate_token_len(
    self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
) -> None:
    """Ensure the prompt — and any requested completion — fits the context window.

    Args:
        obj: The incoming request; ``obj.sampling_params`` may carry a
            ``max_new_tokens`` cap for the completion.
        input_ids: Tokenized prompt (``None`` is treated as an empty prompt).

    Raises:
        ValueError: If the prompt alone, or prompt plus ``max_new_tokens``,
            reaches ``self.context_len``.
    """
    prompt_tokens = 0 if input_ids is None else len(input_ids)

    # The prompt by itself must leave room inside the context window.
    if prompt_tokens >= self.context_len:
        raise ValueError(
            f"The input ({prompt_tokens} tokens) is longer than the "
            f"model's context length ({self.context_len} tokens)."
        )

    # When the caller capped generation, the combined budget must also fit.
    max_new_tokens = obj.sampling_params.get("max_new_tokens")
    if max_new_tokens is None:
        return
    total_tokens = max_new_tokens + prompt_tokens
    if total_tokens >= self.context_len:
        raise ValueError(
            f"Requested token count exceeds the model's maximum context length "
            f"of {self.context_len} tokens. You requested a total of {total_tokens} "
            f"tokens: {prompt_tokens} tokens from the input messages and "
            f"{max_new_tokens} tokens for the completion. Please reduce the number "
            f"of tokens in the input messages or the completion to fit within the limit."
        )
def _create_tokenized_object(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
input_text: str,
input_ids: List[int],
input_embeds: Optional[Union[List[float], None]] = None,
image_inputs: Optional[Dict] = None,
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
"""Create a tokenized request object from common parameters."""
if self.is_generation:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
@@ -424,29 +469,6 @@ class TokenizerManager:
SessionParams(**obj.session_params) if obj.session_params else None
)
input_token_num = len(input_ids) if input_ids is not None else 0
if input_token_num >= self.context_len:
raise ValueError(
f"The input ({input_token_num} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
)
if (
obj.sampling_params.get("max_new_tokens") is not None
and obj.sampling_params.get("max_new_tokens") + input_token_num
>= self.context_len
):
raise ValueError(
f"Requested token count exceeds the model's maximum context length "
f"of {self.context_len} tokens. You requested a total of "
f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
f"tokens: {input_token_num} tokens from the input messages and "
f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
f"completion. Please reduce the number of tokens in the input "
f"messages or the completion to fit within the limit."
)
# Parse sampling parameters
sampling_params = SamplingParams(**obj.sampling_params)
sampling_params.normalize(self.tokenizer)
sampling_params.verify()
@@ -483,6 +505,50 @@ class TokenizerManager:
return tokenized_obj
async def _batch_tokenize_and_process(
    self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
    """Handle batch tokenization for text inputs only.

    Tokenizes every request's text in a single tokenizer call, then
    validates each token count and wraps each request into its tokenized
    form. Callers must have already rejected image/ids/embeds inputs.
    """
    logger.debug(f"Starting batch tokenization for {batch_size} text requests")

    # One tokenizer call over all texts amortizes per-call overhead.
    batch_requests = [obj[idx] for idx in range(batch_size)]
    all_input_ids = self.tokenizer([req.text for req in batch_requests])["input_ids"]

    tokenized_objs = []
    for req, ids in zip(batch_requests, all_input_ids):
        self._validate_token_len(req, ids)
        tokenized_objs.append(
            self._create_tokenized_object(req, req.text, ids, None, None)
        )

    logger.debug(f"Completed batch processing for {batch_size} requests")
    return tokenized_objs
def _validate_batch_tokenization_constraints(
    self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
) -> None:
    """Validate constraints for batch tokenization processing.

    Batch encoding only handles plain-text prompts, so any request that
    carries image data, pre-tokenized ids, or embeddings is rejected.

    Raises:
        ValueError: If any request in the batch is incompatible with
            ``enable_tokenizer_batch_encode``.
    """
    for idx in range(batch_size):
        req = obj[idx]
        if self.is_generation and req.image_data:
            raise ValueError(
                "For image input processing do not set `enable_tokenizer_batch_encode`."
            )
        if req.input_ids is not None:
            raise ValueError(
                "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
            )
        if req.input_embeds is not None:
            raise ValueError(
                "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
            )
def _send_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -560,14 +626,27 @@ class TokenizerManager:
generators = []
rids = []
if getattr(obj, "parallel_sample_num", 1) == 1:
# Send all requests
for i in range(batch_size):
tmp_obj = obj[i]
tokenized_obj = await self._tokenize_one_request(tmp_obj)
self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, request))
rids.append(tmp_obj.rid)
if self.server_args.enable_tokenizer_batch_encode:
# Validate batch tokenization constraints
self._validate_batch_tokenization_constraints(batch_size, obj)
tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
for i, tokenized_obj in enumerate(tokenized_objs):
tmp_obj = obj[i]
self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, request))
rids.append(tmp_obj.rid)
else:
# Sequential tokenization and processing
for i in range(batch_size):
tmp_obj = obj[i]
tokenized_obj = await self._tokenize_one_request(tmp_obj)
self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, request))
rids.append(tmp_obj.rid)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if batch_size > 128:

View File

@@ -49,6 +49,7 @@ class ServerArgs:
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False
enable_tokenizer_batch_encode: bool = False
load_format: str = "auto"
trust_remote_code: bool = False
dtype: str = "auto"
@@ -432,6 +433,11 @@ class ServerArgs:
action="store_true",
help="If set, skip init tokenizer and pass input_ids in generate request",
)
parser.add_argument(
"--enable-tokenizer-batch-encode",
action="store_true",
help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
)
parser.add_argument(
"--load-format",
type=str,