Perform Batch Tokenization. (#5141)
This commit is contained in:
committed by
GitHub
parent
2b3bdc938e
commit
f08154193c
193
benchmark/benchmark_batch/benchmark_batch.py
Normal file
193
benchmark/benchmark_batch/benchmark_batch.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
import concurrent.futures
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# CONFIG
|
||||||
|
###############################################################################
|
||||||
|
ENDPOINT_URL = "http://127.0.0.1:30000"
|
||||||
|
TOKENIZER_DIR = "/models/meta-llama/Llama-3.2-3B"
|
||||||
|
|
||||||
|
# Benchmark configurations
|
||||||
|
NUM_REQUESTS = 10 # Total number of requests (each with BATCH_SIZE prompts)
|
||||||
|
NUM_TOKENS = 32000 # Tokens per prompt
|
||||||
|
BATCH_SIZE = 8 # Number of prompts per request
|
||||||
|
GEN_TOKENS = 0 # Tokens to generate per prompt
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# REQUEST GENERATION (in parallel)
|
||||||
|
###############################################################################
|
||||||
|
def generate_random_prompt(index, tokenizer_dir, num_tokens):
|
||||||
|
"""Generate a single random prompt with specified token count."""
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
||||||
|
vocab_size = tokenizer.vocab_size
|
||||||
|
|
||||||
|
def generate_random_text(num_toks):
|
||||||
|
random_token_ids = [random.randint(0, vocab_size - 1) for _ in range(num_toks)]
|
||||||
|
return tokenizer.decode(random_token_ids, clean_up_tokenization_spaces=True)
|
||||||
|
|
||||||
|
random_text = generate_random_text(num_tokens)
|
||||||
|
return f"Prompt {index}: {random_text}"
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_all_prompts(num_requests, batch_size, num_tokens, tokenizer_dir):
|
||||||
|
"""Generate prompts for all requests in parallel."""
|
||||||
|
total_prompts = num_requests * batch_size
|
||||||
|
all_prompts = [None] * total_prompts
|
||||||
|
max_workers = min(os.cpu_count() or 1, total_prompts)
|
||||||
|
|
||||||
|
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
futures = [
|
||||||
|
executor.submit(generate_random_prompt, i, tokenizer_dir, num_tokens)
|
||||||
|
for i in range(total_prompts)
|
||||||
|
]
|
||||||
|
for future in tqdm(
|
||||||
|
concurrent.futures.as_completed(futures),
|
||||||
|
total=total_prompts,
|
||||||
|
desc="Generating prompts",
|
||||||
|
):
|
||||||
|
index = futures.index(future)
|
||||||
|
all_prompts[index] = future.result()
|
||||||
|
|
||||||
|
batched_prompts = [
|
||||||
|
all_prompts[i * batch_size : (i + 1) * batch_size] for i in range(num_requests)
|
||||||
|
]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Generated {total_prompts} prompts with {num_tokens} tokens each, grouped into {num_requests} requests of {batch_size} prompts.\n"
|
||||||
|
)
|
||||||
|
return batched_prompts
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# HTTP CALLS
|
||||||
|
###############################################################################
|
||||||
|
def send_batch_request(endpoint, prompts, gen_tokens, request_id):
|
||||||
|
"""Send a batch of prompts to the /generate endpoint synchronously."""
|
||||||
|
sampling_params = {
|
||||||
|
"max_new_tokens": gen_tokens,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"stop": "\n",
|
||||||
|
}
|
||||||
|
data = {"text": prompts, "sampling_params": sampling_params}
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
endpoint.base_url + "/generate", json=data, timeout=3600
|
||||||
|
)
|
||||||
|
if response.status_code != 200:
|
||||||
|
error = response.json()
|
||||||
|
raise RuntimeError(f"Request {request_id} failed: {error}")
|
||||||
|
result = response.json()
|
||||||
|
elapsed_time = (time.time() - start_time) * 1000 # Convert to ms
|
||||||
|
avg_per_prompt = elapsed_time / len(prompts) if prompts else 0
|
||||||
|
return request_id, elapsed_time, avg_per_prompt, True, len(prompts)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Request] Error for request {request_id}: {e}")
|
||||||
|
return request_id, 0, 0, False, len(prompts)
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark(endpoint, batched_prompts, batch_size, gen_tokens):
|
||||||
|
"""Run the benchmark sequentially."""
|
||||||
|
results = []
|
||||||
|
num_requests = len(batched_prompts)
|
||||||
|
|
||||||
|
# Record start time for total latency
|
||||||
|
benchmark_start_time = time.time()
|
||||||
|
|
||||||
|
for i, batch_prompts in enumerate(batched_prompts):
|
||||||
|
request_id = i + 1
|
||||||
|
assert (
|
||||||
|
len(batch_prompts) == batch_size
|
||||||
|
), f"Request {request_id} should have {batch_size} prompts, got {len(batch_prompts)}"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[Request] Sending request {request_id}/{num_requests} with {len(batch_prompts)} prompts at {int(time.time()*1000)}"
|
||||||
|
)
|
||||||
|
result = send_batch_request(endpoint, batch_prompts, gen_tokens, request_id)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
# Calculate total latency
|
||||||
|
total_latency = (time.time() - benchmark_start_time) * 1000 # Convert to ms
|
||||||
|
|
||||||
|
return results, total_latency
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# RESULTS
|
||||||
|
###############################################################################
|
||||||
|
def process_results(results, total_latency, num_requests):
|
||||||
|
"""Process and display benchmark results."""
|
||||||
|
total_time = 0
|
||||||
|
successful_requests = 0
|
||||||
|
failed_requests = 0
|
||||||
|
request_latencies = []
|
||||||
|
per_prompt_latencies = []
|
||||||
|
total_prompts = 0
|
||||||
|
|
||||||
|
for request_id, elapsed_time, avg_per_prompt, success, batch_size in results:
|
||||||
|
if success:
|
||||||
|
successful_requests += 1
|
||||||
|
total_prompts += batch_size
|
||||||
|
request_latencies.append(elapsed_time)
|
||||||
|
per_prompt_latencies.append(avg_per_prompt)
|
||||||
|
total_time += elapsed_time / 1000 # Convert to seconds
|
||||||
|
else:
|
||||||
|
failed_requests += 1
|
||||||
|
|
||||||
|
avg_request_latency = mean(request_latencies) if request_latencies else 0
|
||||||
|
avg_per_prompt_latency = mean(per_prompt_latencies) if per_prompt_latencies else 0
|
||||||
|
throughput = total_prompts / total_time if total_time > 0 else 0
|
||||||
|
|
||||||
|
print("\nBenchmark Summary:")
|
||||||
|
print(f" Total requests sent: {len(results)}")
|
||||||
|
print(f" Total prompts sent: {total_prompts}")
|
||||||
|
print(f" Successful requests: {successful_requests}")
|
||||||
|
print(f" Failed requests: {failed_requests}")
|
||||||
|
print(f" Total latency (all requests): {total_latency:.2f} ms")
|
||||||
|
print(f" Avg per request latency: {avg_request_latency:.2f} ms")
|
||||||
|
print(f" Avg per prompt latency: {avg_per_prompt_latency:.2f} ms")
|
||||||
|
print(f" Throughput: {throughput:.2f} prompts/second\n")
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# MAIN
|
||||||
|
###############################################################################
|
||||||
|
def main():
|
||||||
|
# Initialize endpoint
|
||||||
|
endpoint = RuntimeEndpoint(ENDPOINT_URL)
|
||||||
|
|
||||||
|
# Generate prompts
|
||||||
|
batched_prompts = prepare_all_prompts(
|
||||||
|
NUM_REQUESTS, BATCH_SIZE, NUM_TOKENS, TOKENIZER_DIR
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flush cache before benchmark
|
||||||
|
# endpoint.flush_cache()
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
print(
|
||||||
|
f"Starting benchmark: NUM_TOKENS={NUM_TOKENS}, BATCH_SIZE={BATCH_SIZE}, NUM_REQUESTS={NUM_REQUESTS}\n"
|
||||||
|
)
|
||||||
|
results, total_latency = run_benchmark(
|
||||||
|
endpoint, batched_prompts, BATCH_SIZE, GEN_TOKENS
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process and display results
|
||||||
|
process_results(results, total_latency, NUM_REQUESTS)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
random.seed(0)
|
||||||
|
main()
|
||||||
126
benchmark/benchmark_batch/benchmark_tokenizer.py
Normal file
126
benchmark/benchmark_batch/benchmark_tokenizer.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import random
|
||||||
|
import time
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
# CONFIG
|
||||||
|
TOKENIZER_DIR = (
|
||||||
|
"/shared/public/sharing/fait360brew/training/models/meta-llama/Llama-3.2-3B"
|
||||||
|
)
|
||||||
|
NUM_TOKENS = 20000 # Each prompt should contain this many tokens
|
||||||
|
BATCH_SIZES = [1, 2, 4, 8] # Test different batch sizes
|
||||||
|
NUM_RUNS = 5 # Number of runs for each batch size to get reliable measurements
|
||||||
|
|
||||||
|
|
||||||
|
def generate_random_prompts(num_prompts, num_tokens, tokenizer):
|
||||||
|
"""Generate random prompts with specified token count."""
|
||||||
|
vocab_size = tokenizer.vocab_size
|
||||||
|
all_prompts = []
|
||||||
|
|
||||||
|
print(f"Generating {num_prompts} random prompts with {num_tokens} tokens each...")
|
||||||
|
for i in range(num_prompts):
|
||||||
|
# Generate random token IDs - this directly gives us the exact token count
|
||||||
|
random_token_ids = [
|
||||||
|
random.randint(0, vocab_size - 1) for _ in range(num_tokens)
|
||||||
|
]
|
||||||
|
random_text = tokenizer.decode(
|
||||||
|
random_token_ids, clean_up_tokenization_spaces=True
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = f"Prompt {i}: {random_text}"
|
||||||
|
tokens = tokenizer.encode(prompt)
|
||||||
|
print(f" Prompt {i}: {len(tokens)} tokens")
|
||||||
|
all_prompts.append(prompt)
|
||||||
|
|
||||||
|
return all_prompts
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_sequential_vs_batch(prompts, batch_size, tokenizer):
|
||||||
|
"""Compare sequential vs batch tokenization for a given batch size."""
|
||||||
|
|
||||||
|
# Sequential tokenization using encode()
|
||||||
|
sequential_times = []
|
||||||
|
for run in range(NUM_RUNS):
|
||||||
|
batch_prompts = prompts[:batch_size] # Use same prompts for fair comparison
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
for prompt in batch_prompts:
|
||||||
|
tokens = tokenizer.encode(prompt)
|
||||||
|
sequential_time = (time.time() - start_time) * 1000
|
||||||
|
sequential_times.append(sequential_time)
|
||||||
|
|
||||||
|
# Batch tokenization using tokenizer()
|
||||||
|
batch_times = []
|
||||||
|
for run in range(NUM_RUNS):
|
||||||
|
batch_prompts = prompts[:batch_size] # Use same prompts for fair comparison
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
tokens = tokenizer(batch_prompts)
|
||||||
|
batch_time = (time.time() - start_time) * 1000
|
||||||
|
batch_times.append(batch_time)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"avg_sequential_ms": mean(sequential_times),
|
||||||
|
"avg_batch_ms": mean(batch_times),
|
||||||
|
"speedup_factor": (
|
||||||
|
mean(sequential_times) / mean(batch_times) if mean(batch_times) > 0 else 0
|
||||||
|
),
|
||||||
|
"sequential_runs": sequential_times,
|
||||||
|
"batch_runs": batch_times,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Tokenizer Benchmark: Sequential vs Batch Processing")
|
||||||
|
print("-" * 60)
|
||||||
|
print(f"Tokenizer: {TOKENIZER_DIR}")
|
||||||
|
print(f"Tokens per prompt: {NUM_TOKENS}")
|
||||||
|
print(f"Number of runs per batch size: {NUM_RUNS}")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
# Load tokenizer once for all operations
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
|
||||||
|
|
||||||
|
# The largest batch size determines how many prompts we need
|
||||||
|
max_batch_size = max(BATCH_SIZES)
|
||||||
|
all_prompts = generate_random_prompts(max_batch_size, NUM_TOKENS, tokenizer)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
print("\nRunning benchmark...")
|
||||||
|
|
||||||
|
for batch_size in BATCH_SIZES:
|
||||||
|
print(f"\nBenchmarking batch size: {batch_size}")
|
||||||
|
result = benchmark_sequential_vs_batch(all_prompts, batch_size, tokenizer)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
print(f" Sequential tokenization (encode):")
|
||||||
|
for i, run_time in enumerate(result["sequential_runs"]):
|
||||||
|
print(f" Run {i+1}: {run_time:.2f} ms")
|
||||||
|
print(f" Average: {result['avg_sequential_ms']:.2f} ms")
|
||||||
|
|
||||||
|
print(f" Batch tokenization (tokenizer):")
|
||||||
|
for i, run_time in enumerate(result["batch_runs"]):
|
||||||
|
print(f" Run {i+1}: {run_time:.2f} ms")
|
||||||
|
print(f" Average: {result['avg_batch_ms']:.2f} ms")
|
||||||
|
|
||||||
|
print(f" Speedup factor: {result['speedup_factor']:.2f}x")
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("SUMMARY OF RESULTS")
|
||||||
|
print("=" * 60)
|
||||||
|
print(
|
||||||
|
f"{'Batch Size':<10} {'Sequential (ms)':<18} {'Batch (ms)':<18} {'Speedup':<10}"
|
||||||
|
)
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
print(
|
||||||
|
f"{result['batch_size']:<10} {result['avg_sequential_ms']:.2f} ms{' ' * 8} {result['avg_batch_ms']:.2f} ms{' ' * 8} {result['speedup_factor']:.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
random.seed(0)
|
||||||
|
main()
|
||||||
@@ -415,6 +415,51 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
if image_inputs and "input_ids" in image_inputs:
|
if image_inputs and "input_ids" in image_inputs:
|
||||||
input_ids = image_inputs["input_ids"]
|
input_ids = image_inputs["input_ids"]
|
||||||
|
|
||||||
|
self._validate_token_len(obj, input_ids)
|
||||||
|
return self._create_tokenized_object(
|
||||||
|
obj, input_text, input_ids, input_embeds, image_inputs
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_token_len(
|
||||||
|
self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
|
||||||
|
) -> None:
|
||||||
|
"""Validates that the input token count and the requested token count doesn't exceed the model's context length."""
|
||||||
|
|
||||||
|
input_token_num = len(input_ids) if input_ids is not None else 0
|
||||||
|
# Check if input alone exceeds context length
|
||||||
|
if input_token_num >= self.context_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"The input ({input_token_num} tokens) is longer than the "
|
||||||
|
f"model's context length ({self.context_len} tokens)."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check total tokens (input + max_new_tokens)
|
||||||
|
max_new_tokens = obj.sampling_params.get("max_new_tokens")
|
||||||
|
if (
|
||||||
|
max_new_tokens is not None
|
||||||
|
and (max_new_tokens + input_token_num) >= self.context_len
|
||||||
|
):
|
||||||
|
total_tokens = max_new_tokens + input_token_num
|
||||||
|
error_msg = (
|
||||||
|
f"Requested token count exceeds the model's maximum context length "
|
||||||
|
f"of {self.context_len} tokens. You requested a total of {total_tokens} "
|
||||||
|
f"tokens: {input_token_num} tokens from the input messages and "
|
||||||
|
f"{max_new_tokens} tokens for the completion. Please reduce the number "
|
||||||
|
f"of tokens in the input messages or the completion to fit within the limit."
|
||||||
|
)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
def _create_tokenized_object(
|
||||||
|
self,
|
||||||
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
|
input_text: str,
|
||||||
|
input_ids: List[int],
|
||||||
|
input_embeds: Optional[Union[List[float], None]] = None,
|
||||||
|
image_inputs: Optional[Dict] = None,
|
||||||
|
) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
|
||||||
|
"""Create a tokenized request object from common parameters."""
|
||||||
|
|
||||||
if self.is_generation:
|
if self.is_generation:
|
||||||
return_logprob = obj.return_logprob
|
return_logprob = obj.return_logprob
|
||||||
logprob_start_len = obj.logprob_start_len
|
logprob_start_len = obj.logprob_start_len
|
||||||
@@ -424,29 +469,6 @@ class TokenizerManager:
|
|||||||
SessionParams(**obj.session_params) if obj.session_params else None
|
SessionParams(**obj.session_params) if obj.session_params else None
|
||||||
)
|
)
|
||||||
|
|
||||||
input_token_num = len(input_ids) if input_ids is not None else 0
|
|
||||||
if input_token_num >= self.context_len:
|
|
||||||
raise ValueError(
|
|
||||||
f"The input ({input_token_num} tokens) is longer than the "
|
|
||||||
f"model's context length ({self.context_len} tokens)."
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
obj.sampling_params.get("max_new_tokens") is not None
|
|
||||||
and obj.sampling_params.get("max_new_tokens") + input_token_num
|
|
||||||
>= self.context_len
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Requested token count exceeds the model's maximum context length "
|
|
||||||
f"of {self.context_len} tokens. You requested a total of "
|
|
||||||
f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
|
|
||||||
f"tokens: {input_token_num} tokens from the input messages and "
|
|
||||||
f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
|
|
||||||
f"completion. Please reduce the number of tokens in the input "
|
|
||||||
f"messages or the completion to fit within the limit."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Parse sampling parameters
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params)
|
sampling_params = SamplingParams(**obj.sampling_params)
|
||||||
sampling_params.normalize(self.tokenizer)
|
sampling_params.normalize(self.tokenizer)
|
||||||
sampling_params.verify()
|
sampling_params.verify()
|
||||||
@@ -483,6 +505,50 @@ class TokenizerManager:
|
|||||||
|
|
||||||
return tokenized_obj
|
return tokenized_obj
|
||||||
|
|
||||||
|
async def _batch_tokenize_and_process(
|
||||||
|
self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
|
||||||
|
) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
|
||||||
|
"""Handle batch tokenization for text inputs only."""
|
||||||
|
logger.debug(f"Starting batch tokenization for {batch_size} text requests")
|
||||||
|
|
||||||
|
# Collect requests and texts
|
||||||
|
requests = [obj[i] for i in range(batch_size)]
|
||||||
|
texts = [req.text for req in requests]
|
||||||
|
|
||||||
|
# Batch tokenize all texts
|
||||||
|
encoded = self.tokenizer(texts)
|
||||||
|
input_ids_list = encoded["input_ids"]
|
||||||
|
|
||||||
|
# Process all requests
|
||||||
|
tokenized_objs = []
|
||||||
|
for i, req in enumerate(requests):
|
||||||
|
self._validate_token_len(obj[i], input_ids_list[i])
|
||||||
|
tokenized_objs.append(
|
||||||
|
self._create_tokenized_object(
|
||||||
|
req, req.text, input_ids_list[i], None, None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.debug(f"Completed batch processing for {batch_size} requests")
|
||||||
|
return tokenized_objs
|
||||||
|
|
||||||
|
def _validate_batch_tokenization_constraints(
|
||||||
|
self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
|
||||||
|
) -> None:
|
||||||
|
"""Validate constraints for batch tokenization processing."""
|
||||||
|
for i in range(batch_size):
|
||||||
|
if self.is_generation and obj[i].image_data:
|
||||||
|
raise ValueError(
|
||||||
|
"For image input processing do not set `enable_tokenizer_batch_encode`."
|
||||||
|
)
|
||||||
|
if obj[i].input_ids is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
|
||||||
|
)
|
||||||
|
if obj[i].input_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
|
||||||
|
)
|
||||||
|
|
||||||
def _send_one_request(
|
def _send_one_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
@@ -560,14 +626,27 @@ class TokenizerManager:
|
|||||||
|
|
||||||
generators = []
|
generators = []
|
||||||
rids = []
|
rids = []
|
||||||
|
|
||||||
if getattr(obj, "parallel_sample_num", 1) == 1:
|
if getattr(obj, "parallel_sample_num", 1) == 1:
|
||||||
# Send all requests
|
if self.server_args.enable_tokenizer_batch_encode:
|
||||||
for i in range(batch_size):
|
# Validate batch tokenization constraints
|
||||||
tmp_obj = obj[i]
|
self._validate_batch_tokenization_constraints(batch_size, obj)
|
||||||
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
|
||||||
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
|
||||||
generators.append(self._wait_one_response(tmp_obj, request))
|
|
||||||
rids.append(tmp_obj.rid)
|
for i, tokenized_obj in enumerate(tokenized_objs):
|
||||||
|
tmp_obj = obj[i]
|
||||||
|
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||||
|
generators.append(self._wait_one_response(tmp_obj, request))
|
||||||
|
rids.append(tmp_obj.rid)
|
||||||
|
else:
|
||||||
|
# Sequential tokenization and processing
|
||||||
|
for i in range(batch_size):
|
||||||
|
tmp_obj = obj[i]
|
||||||
|
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
||||||
|
self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||||
|
generators.append(self._wait_one_response(tmp_obj, request))
|
||||||
|
rids.append(tmp_obj.rid)
|
||||||
else:
|
else:
|
||||||
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
||||||
if batch_size > 128:
|
if batch_size > 128:
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class ServerArgs:
|
|||||||
tokenizer_path: Optional[str] = None
|
tokenizer_path: Optional[str] = None
|
||||||
tokenizer_mode: str = "auto"
|
tokenizer_mode: str = "auto"
|
||||||
skip_tokenizer_init: bool = False
|
skip_tokenizer_init: bool = False
|
||||||
|
enable_tokenizer_batch_encode: bool = False
|
||||||
load_format: str = "auto"
|
load_format: str = "auto"
|
||||||
trust_remote_code: bool = False
|
trust_remote_code: bool = False
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
@@ -432,6 +433,11 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="If set, skip init tokenizer and pass input_ids in generate request",
|
help="If set, skip init tokenizer and pass input_ids in generate request",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-tokenizer-batch-encode",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load-format",
|
"--load-format",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user