Perform Batch Tokenization. (#5141)
Commit f08154193c, committed via GitHub; parent commit 2b3bdc938e.
@@ -415,6 +415,51 @@ class TokenizerManager:
|
||||
)
|
||||
if image_inputs and "input_ids" in image_inputs:
|
||||
input_ids = image_inputs["input_ids"]
|
||||
|
||||
self._validate_token_len(obj, input_ids)
|
||||
return self._create_tokenized_object(
|
||||
obj, input_text, input_ids, input_embeds, image_inputs
|
||||
)
|
||||
|
||||
def _validate_token_len(
|
||||
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
|
||||
) -> None:
|
||||
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
|
||||
|
||||
input_token_num = len(input_ids) if input_ids is not None else 0
|
||||
# Check if input alone exceeds context length
|
||||
if input_token_num >= self.context_len:
|
||||
raise ValueError(
|
||||
f"The input ({input_token_num} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens)."
|
||||
)
|
||||
|
||||
# Check total tokens (input + max_new_tokens)
|
||||
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
||||
if (
|
||||
max_new_tokens is not None
|
||||
and (max_new_tokens + input_token_num) >= self.context_len
|
||||
):
|
||||
total_tokens = max_new_tokens + input_token_num
|
||||
error_msg = (
|
||||
f"Requested token count exceeds the model's maximum context length "
|
||||
f"of {self.context_len} tokens. You requested a total of {total_tokens} "
|
||||
f"tokens: {input_token_num} tokens from the input messages and "
|
||||
f"{max_new_tokens} tokens for the completion. Please reduce the number "
|
||||
f"of tokens in the input messages or the completion to fit within the limit."
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def _create_tokenized_object(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
input_text: str,
|
||||
input_ids: List[int],
|
||||
input_embeds: Optional[Union[List[float], None]] = None,
|
||||
image_inputs: Optional[Dict] = None,
|
||||
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
|
||||
"""Create a tokenized request object from common parameters."""
|
||||
|
||||
if self.is_generation:
|
||||
return_logprob = obj.return_logprob
|
||||
logprob_start_len = obj.logprob_start_len
|
||||
@@ -424,29 +469,6 @@ class TokenizerManager:
|
||||
SessionParams(**obj.session_params) if obj.session_params else None
|
||||
)
|
||||
|
||||
input_token_num = len(input_ids) if input_ids is not None else 0
|
||||
if input_token_num >= self.context_len:
|
||||
raise ValueError(
|
||||
f"The input ({input_token_num} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens)."
|
||||
)
|
||||
|
||||
if (
|
||||
obj.sampling_params.get("max_new_tokens") is not None
|
||||
and obj.sampling_params.get("max_new_tokens") + input_token_num
|
||||
>= self.context_len
|
||||
):
|
||||
raise ValueError(
|
||||
f"Requested token count exceeds the model's maximum context length "
|
||||
f"of {self.context_len} tokens. You requested a total of "
|
||||
f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
|
||||
f"tokens: {input_token_num} tokens from the input messages and "
|
||||
f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
|
||||
f"completion. Please reduce the number of tokens in the input "
|
||||
f"messages or the completion to fit within the limit."
|
||||
)
|
||||
|
||||
# Parse sampling parameters
|
||||
sampling_params = SamplingParams(**obj.sampling_params)
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
@@ -483,6 +505,50 @@ class TokenizerManager:
|
||||
|
||||
return tokenized_obj
|
||||
|
||||
async def _batch_tokenize_and_process(
    self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
    """Handle batch tokenization for text inputs only.

    Runs the tokenizer once over all texts in the batch (amortizing tokenizer
    overhead — the point of this path), then validates and wraps each request
    individually.

    Args:
        batch_size: number of requests contained in ``obj``.
        obj: batched request; ``obj[i]`` yields the i-th single request.

    Returns:
        One tokenized request object per input request, in order.
    """
    logger.debug(f"Starting batch tokenization for {batch_size} text requests")

    # Materialize the per-request views exactly once. Indexing `obj` again
    # later may construct a fresh view, so every subsequent step (validation,
    # wrapping) reuses the same `req` instance. The original code called
    # `self._validate_token_len(obj[i], ...)`, validating a potentially
    # different object than the one tokenized.
    requests = [obj[i] for i in range(batch_size)]
    texts = [req.text for req in requests]

    # Single batched tokenizer call over all texts.
    encoded = self.tokenizer(texts)
    input_ids_list = encoded["input_ids"]

    # Validate and wrap each request with its batch-tokenized ids.
    tokenized_objs = []
    for req, input_ids in zip(requests, input_ids_list):
        self._validate_token_len(req, input_ids)
        tokenized_objs.append(
            self._create_tokenized_object(req, req.text, input_ids, None, None)
        )
    logger.debug(f"Completed batch processing for {batch_size} requests")
    return tokenized_objs
|
||||
|
||||
def _validate_batch_tokenization_constraints(
|
||||
self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
|
||||
) -> None:
|
||||
"""Validate constraints for batch tokenization processing."""
|
||||
for i in range(batch_size):
|
||||
if self.is_generation and obj[i].image_data:
|
||||
raise ValueError(
|
||||
"For image input processing do not set `enable_tokenizer_batch_encode`."
|
||||
)
|
||||
if obj[i].input_ids is not None:
|
||||
raise ValueError(
|
||||
"Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
|
||||
)
|
||||
if obj[i].input_embeds is not None:
|
||||
raise ValueError(
|
||||
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
|
||||
)
|
||||
|
||||
def _send_one_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
@@ -560,14 +626,27 @@ class TokenizerManager:
|
||||
|
||||
generators = []
|
||||
rids = []
|
||||
|
||||
if getattr(obj, "parallel_sample_num", 1) == 1:
|
||||
# Send all requests
|
||||
for i in range(batch_size):
|
||||
tmp_obj = obj[i]
|
||||
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||
generators.append(self._wait_one_response(tmp_obj, request))
|
||||
rids.append(tmp_obj.rid)
|
||||
if self.server_args.enable_tokenizer_batch_encode:
|
||||
# Validate batch tokenization constraints
|
||||
self._validate_batch_tokenization_constraints(batch_size, obj)
|
||||
|
||||
tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
|
||||
|
||||
for i, tokenized_obj in enumerate(tokenized_objs):
|
||||
tmp_obj = obj[i]
|
||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||
generators.append(self._wait_one_response(tmp_obj, request))
|
||||
rids.append(tmp_obj.rid)
|
||||
else:
|
||||
# Sequential tokenization and processing
|
||||
for i in range(batch_size):
|
||||
tmp_obj = obj[i]
|
||||
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||
generators.append(self._wait_one_response(tmp_obj, request))
|
||||
rids.append(tmp_obj.rid)
|
||||
else:
|
||||
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
||||
if batch_size > 128:
|
||||
|
||||
@@ -49,6 +49,7 @@ class ServerArgs:
|
||||
tokenizer_path: Optional[str] = None
|
||||
tokenizer_mode: str = "auto"
|
||||
skip_tokenizer_init: bool = False
|
||||
enable_tokenizer_batch_encode: bool = False
|
||||
load_format: str = "auto"
|
||||
trust_remote_code: bool = False
|
||||
dtype: str = "auto"
|
||||
@@ -432,6 +433,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="If set, skip init tokenizer and pass input_ids in generate request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-tokenizer-batch-encode",
|
||||
action="store_true",
|
||||
help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-format",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user