Perform Batch Tokenization. (#5141)
committed by GitHub
parent 2b3bdc938e
commit f08154193c
126 benchmark/benchmark_batch/benchmark_tokenizer.py Normal file
@@ -0,0 +1,126 @@
import random
import time
from statistics import mean

from transformers import AutoTokenizer

# CONFIG
TOKENIZER_DIR = (
    "/shared/public/sharing/fait360brew/training/models/meta-llama/Llama-3.2-3B"
)
NUM_TOKENS = 20000  # Each prompt should contain this many tokens
BATCH_SIZES = [1, 2, 4, 8]  # Test different batch sizes
NUM_RUNS = 5  # Number of runs for each batch size to get reliable measurements

||||
def generate_random_prompts(num_prompts, num_tokens, tokenizer):
    """Generate random prompts with the specified token count."""
    vocab_size = tokenizer.vocab_size
    all_prompts = []

    print(f"Generating {num_prompts} random prompts with {num_tokens} tokens each...")
    for i in range(num_prompts):
        # Generate random token IDs - this directly gives us the exact token count
        random_token_ids = [
            random.randint(0, vocab_size - 1) for _ in range(num_tokens)
        ]
        random_text = tokenizer.decode(
            random_token_ids, clean_up_tokenization_spaces=True
        )

        prompt = f"Prompt {i}: {random_text}"
        tokens = tokenizer.encode(prompt)
        print(f"  Prompt {i}: {len(tokens)} tokens")
        all_prompts.append(prompt)

    return all_prompts
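
# NOTE: decoding random token IDs and re-encoding the resulting text is not an
# exact round trip for BPE-style tokenizers, so each prompt may not contain
# exactly NUM_TOKENS tokens; the per-prompt print above reports the actual count.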


def benchmark_sequential_vs_batch(prompts, batch_size, tokenizer):
    """Compare sequential vs batch tokenization for a given batch size."""

    # Sequential tokenization using encode()
    sequential_times = []
    for run in range(NUM_RUNS):
        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison

        start_time = time.time()
        for prompt in batch_prompts:
            tokens = tokenizer.encode(prompt)
        sequential_time = (time.time() - start_time) * 1000
        sequential_times.append(sequential_time)

    # Batch tokenization using tokenizer()
    batch_times = []
    for run in range(NUM_RUNS):
        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison

        start_time = time.time()
        tokens = tokenizer(batch_prompts)
        batch_time = (time.time() - start_time) * 1000
        batch_times.append(batch_time)

    return {
        "batch_size": batch_size,
        "avg_sequential_ms": mean(sequential_times),
        "avg_batch_ms": mean(batch_times),
        "speedup_factor": (
            mean(sequential_times) / mean(batch_times) if mean(batch_times) > 0 else 0
        ),
        "sequential_runs": sequential_times,
        "batch_runs": batch_times,
    }
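
# For reference: calling tokenizer(batch_prompts) on a list of strings returns
# a BatchEncoding whose "input_ids" entry holds one token-ID list per prompt,
# and with default settings those IDs should match the per-prompt encode()
# loop above, e.g.:
#     enc = tokenizer(["first prompt", "second prompt"])
#     enc["input_ids"]  # -> [[...], [...]]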


def main():
    print("Tokenizer Benchmark: Sequential vs Batch Processing")
    print("-" * 60)
    print(f"Tokenizer: {TOKENIZER_DIR}")
    print(f"Tokens per prompt: {NUM_TOKENS}")
    print(f"Number of runs per batch size: {NUM_RUNS}")
    print("-" * 60)

    # Load tokenizer once for all operations
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

    # The largest batch size determines how many prompts we need
    max_batch_size = max(BATCH_SIZES)
    all_prompts = generate_random_prompts(max_batch_size, NUM_TOKENS, tokenizer)

    results = []
    print("\nRunning benchmark...")

    for batch_size in BATCH_SIZES:
        print(f"\nBenchmarking batch size: {batch_size}")
        result = benchmark_sequential_vs_batch(all_prompts, batch_size, tokenizer)
        results.append(result)

        print("  Sequential tokenization (encode):")
        for i, run_time in enumerate(result["sequential_runs"]):
            print(f"    Run {i+1}: {run_time:.2f} ms")
        print(f"    Average: {result['avg_sequential_ms']:.2f} ms")

        print("  Batch tokenization (tokenizer):")
        for i, run_time in enumerate(result["batch_runs"]):
            print(f"    Run {i+1}: {run_time:.2f} ms")
        print(f"    Average: {result['avg_batch_ms']:.2f} ms")

        print(f"  Speedup factor: {result['speedup_factor']:.2f}x")

    print("\n" + "=" * 60)
    print("SUMMARY OF RESULTS")
    print("=" * 60)
    print(
        f"{'Batch Size':<10} {'Sequential (ms)':<18} {'Batch (ms)':<18} {'Speedup':<10}"
    )
    print("-" * 60)

    for result in results:
        print(
            f"{result['batch_size']:<10} {result['avg_sequential_ms']:.2f} ms{' ' * 8} "
            f"{result['avg_batch_ms']:.2f} ms{' ' * 8} {result['speedup_factor']:.2f}x"
        )


if __name__ == "__main__":
    random.seed(0)
    main()
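
A quick way to sanity-check the sequential-vs-batch gap outside this script is the sketch below; "gpt2" is only an illustrative stand-in for the Llama tokenizer path hard-coded above (it is not part of this commit), and any local Hugging Face fast tokenizer would work the same way:

import time
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in model, not from this commit
texts = ["hello world, this is a tokenizer smoke test"] * 8

t0 = time.time()
for t in texts:
    tok.encode(t)  # one encode() call per prompt
seq_ms = (time.time() - t0) * 1000

t0 = time.time()
tok(texts)  # single batched call over the same prompts
batch_ms = (time.time() - t0) * 1000

print(f"sequential: {seq_ms:.2f} ms, batch: {batch_ms:.2f} ms")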