diff --git a/benchmark/benchmark_batch/benchmark_batch.py b/benchmark/benchmark_batch/benchmark_batch.py
new file mode 100644
index 000000000..15ef0ab6a
--- /dev/null
+++ b/benchmark/benchmark_batch/benchmark_batch.py
@@ -0,0 +1,193 @@
+import concurrent.futures
+import os
+import random
+import time
+from concurrent.futures import ProcessPoolExecutor
+from statistics import mean
+
+import requests
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+
+###############################################################################
+# CONFIG
+###############################################################################
+ENDPOINT_URL = "http://127.0.0.1:30000"
+TOKENIZER_DIR = "/models/meta-llama/Llama-3.2-3B"
+
+# Benchmark configurations
+NUM_REQUESTS = 10  # Total number of requests (each with BATCH_SIZE prompts)
+NUM_TOKENS = 32000  # Tokens per prompt
+BATCH_SIZE = 8  # Number of prompts per request
+GEN_TOKENS = 0  # Tokens to generate per prompt
+
+
+###############################################################################
+# REQUEST GENERATION (in parallel)
+###############################################################################
+def generate_random_prompt(index, tokenizer_dir, num_tokens):
+    """Generate a single random prompt with specified token count."""
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
+    vocab_size = tokenizer.vocab_size
+
+    def generate_random_text(num_toks):
+        random_token_ids = [random.randint(0, vocab_size - 1) for _ in range(num_toks)]
+        return tokenizer.decode(random_token_ids, clean_up_tokenization_spaces=True)
+
+    random_text = generate_random_text(num_tokens)
+    return f"Prompt {index}: {random_text}"
+
+
+def prepare_all_prompts(num_requests, batch_size, num_tokens, tokenizer_dir):
+    """Generate prompts for all requests in parallel."""
+    total_prompts = num_requests * batch_size
+    all_prompts = [None] * total_prompts
+    max_workers = min(os.cpu_count() or 1, total_prompts)
+
+    with ProcessPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(generate_random_prompt, i, tokenizer_dir, num_tokens)
+            for i in range(total_prompts)
+        ]
+        for future in tqdm(
+            concurrent.futures.as_completed(futures),
+            total=total_prompts,
+            desc="Generating prompts",
+        ):
+            index = futures.index(future)
+            all_prompts[index] = future.result()
+
+    batched_prompts = [
+        all_prompts[i * batch_size : (i + 1) * batch_size] for i in range(num_requests)
+    ]
+
+    print(
+        f"Generated {total_prompts} prompts with {num_tokens} tokens each, grouped into {num_requests} requests of {batch_size} prompts.\n"
+    )
+    return batched_prompts
+
+
+###############################################################################
+# HTTP CALLS
+###############################################################################
+def send_batch_request(endpoint, prompts, gen_tokens, request_id):
+    """Send a batch of prompts to the /generate endpoint synchronously."""
+    sampling_params = {
+        "max_new_tokens": gen_tokens,
+        "temperature": 0.7,
+        "stop": "\n",
+    }
+    data = {"text": prompts, "sampling_params": sampling_params}
+
+    start_time = time.time()
+    try:
+        response = requests.post(
+            endpoint.base_url + "/generate", json=data, timeout=3600
+        )
+        if response.status_code != 200:
+            error = response.json()
+            raise RuntimeError(f"Request {request_id} failed: {error}")
+        result = response.json()
+        elapsed_time = (time.time() - start_time) * 1000  # Convert to ms
+        avg_per_prompt = elapsed_time / len(prompts) if prompts else 0
+        return request_id, elapsed_time, avg_per_prompt, True, len(prompts)
+    except Exception as e:
+        print(f"[Request] Error for request {request_id}: {e}")
+        return request_id, 0, 0, False, len(prompts)
+
+
+def run_benchmark(endpoint, batched_prompts, batch_size, gen_tokens):
+    """Run the benchmark sequentially."""
+    results = []
+    num_requests = len(batched_prompts)
+
+    # Record start time for total latency
+    benchmark_start_time = time.time()
+
+    for i, batch_prompts in enumerate(batched_prompts):
+        request_id = i + 1
+        assert (
+            len(batch_prompts) == batch_size
+        ), f"Request {request_id} should have {batch_size} prompts, got {len(batch_prompts)}"
+
+        print(
+            f"[Request] Sending request {request_id}/{num_requests} with {len(batch_prompts)} prompts at {int(time.time()*1000)}"
+        )
+        result = send_batch_request(endpoint, batch_prompts, gen_tokens, request_id)
+        results.append(result)
+
+    # Calculate total latency
+    total_latency = (time.time() - benchmark_start_time) * 1000  # Convert to ms
+
+    return results, total_latency
+
+
+###############################################################################
+# RESULTS
+###############################################################################
+def process_results(results, total_latency, num_requests):
+    """Process and display benchmark results."""
+    total_time = 0
+    successful_requests = 0
+    failed_requests = 0
+    request_latencies = []
+    per_prompt_latencies = []
+    total_prompts = 0
+
+    for request_id, elapsed_time, avg_per_prompt, success, batch_size in results:
+        if success:
+            successful_requests += 1
+            total_prompts += batch_size
+            request_latencies.append(elapsed_time)
+            per_prompt_latencies.append(avg_per_prompt)
+            total_time += elapsed_time / 1000  # Convert to seconds
+        else:
+            failed_requests += 1
+
+    avg_request_latency = mean(request_latencies) if request_latencies else 0
+    avg_per_prompt_latency = mean(per_prompt_latencies) if per_prompt_latencies else 0
+    throughput = total_prompts / total_time if total_time > 0 else 0
+
+    print("\nBenchmark Summary:")
+    print(f"  Total requests sent: {len(results)}")
+    print(f"  Total prompts sent: {total_prompts}")
+    print(f"  Successful requests: {successful_requests}")
+    print(f"  Failed requests: {failed_requests}")
+    print(f"  Total latency (all requests): {total_latency:.2f} ms")
+    print(f"  Avg per request latency: {avg_request_latency:.2f} ms")
+    print(f"  Avg per prompt latency: {avg_per_prompt_latency:.2f} ms")
+    print(f"  Throughput: {throughput:.2f} prompts/second\n")
+
+
+###############################################################################
+# MAIN
+###############################################################################
+def main():
+    # Initialize endpoint
+    endpoint = RuntimeEndpoint(ENDPOINT_URL)
+
+    # Generate prompts
+    batched_prompts = prepare_all_prompts(
+        NUM_REQUESTS, BATCH_SIZE, NUM_TOKENS, TOKENIZER_DIR
+    )
+
+    # Flush cache before benchmark
+    # endpoint.flush_cache()
+
+    # Run benchmark
+    print(
+        f"Starting benchmark: NUM_TOKENS={NUM_TOKENS}, BATCH_SIZE={BATCH_SIZE}, NUM_REQUESTS={NUM_REQUESTS}\n"
+    )
+    results, total_latency = run_benchmark(
+        endpoint, batched_prompts, BATCH_SIZE, GEN_TOKENS
+    )
+
+    # Process and display results
+    process_results(results, total_latency, NUM_REQUESTS)
+
+
+if __name__ == "__main__":
+    random.seed(0)
+    main()
diff --git a/benchmark/benchmark_batch/benchmark_tokenizer.py b/benchmark/benchmark_batch/benchmark_tokenizer.py
new file mode 100644
index 000000000..c00bfb84b
--- /dev/null
+++ b/benchmark/benchmark_batch/benchmark_tokenizer.py
@@ -0,0 +1,126 @@
+import random
+import time
+from statistics import mean
+
+from transformers import AutoTokenizer
+
+# CONFIG
+TOKENIZER_DIR = (
+    "/shared/public/sharing/fait360brew/training/models/meta-llama/Llama-3.2-3B"
+)
+NUM_TOKENS = 20000  # Each prompt should contain this many tokens
+BATCH_SIZES = [1, 2, 4, 8]  # Test different batch sizes
+NUM_RUNS = 5  # Number of runs for each batch size to get reliable measurements
+
+
+def generate_random_prompts(num_prompts, num_tokens, tokenizer):
+    """Generate random prompts with specified token count."""
+    vocab_size = tokenizer.vocab_size
+    all_prompts = []
+
+    print(f"Generating {num_prompts} random prompts with {num_tokens} tokens each...")
+    for i in range(num_prompts):
+        # Generate random token IDs - this directly gives us the exact token count
+        random_token_ids = [
+            random.randint(0, vocab_size - 1) for _ in range(num_tokens)
+        ]
+        random_text = tokenizer.decode(
+            random_token_ids, clean_up_tokenization_spaces=True
+        )
+
+        prompt = f"Prompt {i}: {random_text}"
+        tokens = tokenizer.encode(prompt)
+        print(f"  Prompt {i}: {len(tokens)} tokens")
+        all_prompts.append(prompt)
+
+    return all_prompts
+
+
+def benchmark_sequential_vs_batch(prompts, batch_size, tokenizer):
+    """Compare sequential vs batch tokenization for a given batch size."""
+
+    # Sequential tokenization using encode()
+    sequential_times = []
+    for run in range(NUM_RUNS):
+        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison
+
+        start_time = time.time()
+        for prompt in batch_prompts:
+            tokens = tokenizer.encode(prompt)
+        sequential_time = (time.time() - start_time) * 1000
+        sequential_times.append(sequential_time)
+
+    # Batch tokenization using tokenizer()
+    batch_times = []
+    for run in range(NUM_RUNS):
+        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison
+
+        start_time = time.time()
+        tokens = tokenizer(batch_prompts)
+        batch_time = (time.time() - start_time) * 1000
+        batch_times.append(batch_time)
+
+    return {
+        "batch_size": batch_size,
+        "avg_sequential_ms": mean(sequential_times),
+        "avg_batch_ms": mean(batch_times),
+        "speedup_factor": (
+            mean(sequential_times) / mean(batch_times) if mean(batch_times) > 0 else 0
+        ),
+        "sequential_runs": sequential_times,
+        "batch_runs": batch_times,
+    }
+
+
+def main():
+    print("Tokenizer Benchmark: Sequential vs Batch Processing")
+    print("-" * 60)
+    print(f"Tokenizer: {TOKENIZER_DIR}")
+    print(f"Tokens per prompt: {NUM_TOKENS}")
+    print(f"Number of runs per batch size: {NUM_RUNS}")
+    print("-" * 60)
+
+    # Load tokenizer once for all operations
+    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
+
+    # The largest batch size determines how many prompts we need
+    max_batch_size = max(BATCH_SIZES)
+    all_prompts = generate_random_prompts(max_batch_size, NUM_TOKENS, tokenizer)
+
+    results = []
+    print("\nRunning benchmark...")
+
+    for batch_size in BATCH_SIZES:
+        print(f"\nBenchmarking batch size: {batch_size}")
+        result = benchmark_sequential_vs_batch(all_prompts, batch_size, tokenizer)
+        results.append(result)
+
+        print(f"  Sequential tokenization (encode):")
+        for i, run_time in enumerate(result["sequential_runs"]):
+            print(f"    Run {i+1}: {run_time:.2f} ms")
+        print(f"    Average: {result['avg_sequential_ms']:.2f} ms")
+
+        print(f"  Batch tokenization (tokenizer):")
+        for i, run_time in enumerate(result["batch_runs"]):
+            print(f"    Run {i+1}: {run_time:.2f} ms")
+        print(f"    Average: {result['avg_batch_ms']:.2f} ms")
+
+ print(f" Speedup factor: {result['speedup_factor']:.2f}x") + + print("\n" + "=" * 60) + print("SUMMARY OF RESULTS") + print("=" * 60) + print( + f"{'Batch Size':<10} {'Sequential (ms)':<18} {'Batch (ms)':<18} {'Speedup':<10}" + ) + print("-" * 60) + + for result in results: + print( + f"{result['batch_size']:<10} {result['avg_sequential_ms']:.2f} ms{' ' * 8} {result['avg_batch_ms']:.2f} ms{' ' * 8} {result['speedup_factor']:.2f}x" + ) + + +if __name__ == "__main__": + random.seed(0) + main() diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index a391dd719..92a6bbafc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -415,6 +415,51 @@ class TokenizerManager: ) if image_inputs and "input_ids" in image_inputs: input_ids = image_inputs["input_ids"] + + self._validate_token_len(obj, input_ids) + return self._create_tokenized_object( + obj, input_text, input_ids, input_embeds, image_inputs + ) + + def _validate_token_len( + self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int] + ) -> None: + """Validates that the input token count and the requested token count doesn't exceed the model's context length.""" + + input_token_num = len(input_ids) if input_ids is not None else 0 + # Check if input alone exceeds context length + if input_token_num >= self.context_len: + raise ValueError( + f"The input ({input_token_num} tokens) is longer than the " + f"model's context length ({self.context_len} tokens)." + ) + + # Check total tokens (input + max_new_tokens) + max_new_tokens = obj.sampling_params.get("max_new_tokens") + if ( + max_new_tokens is not None + and (max_new_tokens + input_token_num) >= self.context_len + ): + total_tokens = max_new_tokens + input_token_num + error_msg = ( + f"Requested token count exceeds the model's maximum context length " + f"of {self.context_len} tokens. You requested a total of {total_tokens} " + f"tokens: {input_token_num} tokens from the input messages and " + f"{max_new_tokens} tokens for the completion. Please reduce the number " + f"of tokens in the input messages or the completion to fit within the limit." + ) + raise ValueError(error_msg) + + def _create_tokenized_object( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + input_text: str, + input_ids: List[int], + input_embeds: Optional[Union[List[float], None]] = None, + image_inputs: Optional[Dict] = None, + ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]: + """Create a tokenized request object from common parameters.""" + if self.is_generation: return_logprob = obj.return_logprob logprob_start_len = obj.logprob_start_len @@ -424,29 +469,6 @@ class TokenizerManager: SessionParams(**obj.session_params) if obj.session_params else None ) - input_token_num = len(input_ids) if input_ids is not None else 0 - if input_token_num >= self.context_len: - raise ValueError( - f"The input ({input_token_num} tokens) is longer than the " - f"model's context length ({self.context_len} tokens)." - ) - - if ( - obj.sampling_params.get("max_new_tokens") is not None - and obj.sampling_params.get("max_new_tokens") + input_token_num - >= self.context_len - ): - raise ValueError( - f"Requested token count exceeds the model's maximum context length " - f"of {self.context_len} tokens. 
You requested a total of " - f"{obj.sampling_params.get('max_new_tokens') + input_token_num} " - f"tokens: {input_token_num} tokens from the input messages and " - f"{obj.sampling_params.get('max_new_tokens')} tokens for the " - f"completion. Please reduce the number of tokens in the input " - f"messages or the completion to fit within the limit." - ) - - # Parse sampling parameters sampling_params = SamplingParams(**obj.sampling_params) sampling_params.normalize(self.tokenizer) sampling_params.verify() @@ -483,6 +505,50 @@ class TokenizerManager: return tokenized_obj + async def _batch_tokenize_and_process( + self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput] + ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]: + """Handle batch tokenization for text inputs only.""" + logger.debug(f"Starting batch tokenization for {batch_size} text requests") + + # Collect requests and texts + requests = [obj[i] for i in range(batch_size)] + texts = [req.text for req in requests] + + # Batch tokenize all texts + encoded = self.tokenizer(texts) + input_ids_list = encoded["input_ids"] + + # Process all requests + tokenized_objs = [] + for i, req in enumerate(requests): + self._validate_token_len(obj[i], input_ids_list[i]) + tokenized_objs.append( + self._create_tokenized_object( + req, req.text, input_ids_list[i], None, None + ) + ) + logger.debug(f"Completed batch processing for {batch_size} requests") + return tokenized_objs + + def _validate_batch_tokenization_constraints( + self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput] + ) -> None: + """Validate constraints for batch tokenization processing.""" + for i in range(batch_size): + if self.is_generation and obj[i].image_data: + raise ValueError( + "For image input processing do not set `enable_tokenizer_batch_encode`." + ) + if obj[i].input_ids is not None: + raise ValueError( + "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`." + ) + if obj[i].input_embeds is not None: + raise ValueError( + "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`." + ) + def _send_one_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -560,14 +626,27 @@ class TokenizerManager: generators = [] rids = [] + if getattr(obj, "parallel_sample_num", 1) == 1: - # Send all requests - for i in range(batch_size): - tmp_obj = obj[i] - tokenized_obj = await self._tokenize_one_request(tmp_obj) - self._send_one_request(tmp_obj, tokenized_obj, created_time) - generators.append(self._wait_one_response(tmp_obj, request)) - rids.append(tmp_obj.rid) + if self.server_args.enable_tokenizer_batch_encode: + # Validate batch tokenization constraints + self._validate_batch_tokenization_constraints(batch_size, obj) + + tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj) + + for i, tokenized_obj in enumerate(tokenized_objs): + tmp_obj = obj[i] + self._send_one_request(tmp_obj, tokenized_obj, created_time) + generators.append(self._wait_one_response(tmp_obj, request)) + rids.append(tmp_obj.rid) + else: + # Sequential tokenization and processing + for i in range(batch_size): + tmp_obj = obj[i] + tokenized_obj = await self._tokenize_one_request(tmp_obj) + self._send_one_request(tmp_obj, tokenized_obj, created_time) + generators.append(self._wait_one_response(tmp_obj, request)) + rids.append(tmp_obj.rid) else: # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. 
             if batch_size > 128:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e1768b52e..ddbbdf35d 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -49,6 +49,7 @@ class ServerArgs:
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
    skip_tokenizer_init: bool = False
+    enable_tokenizer_batch_encode: bool = False
     load_format: str = "auto"
     trust_remote_code: bool = False
     dtype: str = "auto"
@@ -432,6 +433,11 @@
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request",
         )
+        parser.add_argument(
+            "--enable-tokenizer-batch-encode",
+            action="store_true",
+            help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
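
Usage sketch (illustrative, not part of the patch): the new flag only takes effect for requests that carry a list of text prompts, as benchmark_batch.py sends. A minimal end-to-end example, assuming a locally served Llama-3.2-3B on port 30000; the model path, port, and prompt contents below are placeholder assumptions, and only --enable-tokenizer-batch-encode comes from this change.

# Minimal usage sketch (assumptions noted above).
#
# Launch the server with batch tokenization enabled:
#   python -m sglang.launch_server --model-path meta-llama/Llama-3.2-3B \
#       --port 30000 --enable-tokenizer-batch-encode
import requests

# Text-only prompts: no image_data, input_ids, or input_embeds, matching the
# constraints checked in _validate_batch_tokenization_constraints.
prompts = [f"Prompt {i}: a short test prompt" for i in range(8)]
payload = {
    "text": prompts,
    "sampling_params": {"max_new_tokens": 16, "temperature": 0.7},
}
resp = requests.post("http://127.0.0.1:30000/generate", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json())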