diff --git a/benchmark/score/bench_score.py b/benchmark/score/bench_score.py
new file mode 100644
index 000000000..60bcea24c
--- /dev/null
+++ b/benchmark/score/bench_score.py
@@ -0,0 +1,603 @@
+"""
+SGLang Scoring Benchmark Script
+
+This script benchmarks SGLang's scoring API performance using HTTP requests.
+
+Current Features:
+- HTTP-only implementation (open source compatible)
+- Uses the /v1/score API endpoint directly
+- Single-item scoring with batching support
+- Configurable RPS, duration, and batch sizes
+- Progress tracking and detailed metrics
+- Poisson and constant request distributions
+
+Usage:
+- Update the configuration variables at the top of the file
+- Ensure an SGLang server is running on the configured HTTP_URL
+- Run: python bench_score.py
+- Each request will contain ITEM_COUNT_VALUES items for batch scoring
+"""
+
+import asyncio
+import concurrent.futures  # For parallel prompt generation
+import json
+import os
+import random
+from statistics import mean
+
+import aiohttp
+import numpy as np
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+###############################################################################
+# CONFIG
+###############################################################################
+# Server Configuration
+SERVER_TYPE = "HTTP"  # Fixed to HTTP for open source
+
+# HTTP Configuration
+HTTP_URL = "http://localhost:30000/v1/score"  # Use the score API directly
+
+# Score API Config
+# ITEM_COUNT_VALUES determines the number of items per score request (batch size)
+SCORE_QUERY_TOKENS = 120
+SCORE_ITEM_TOKENS = 180
+SCORE_MODEL_PATH = "Qwen/Qwen3-0.6B"
+SCORE_LABEL_TOKEN_IDS = [9454, 2753]  # Yes/No token IDs
+
+# Array of RPS values to test
+RPS_VALUES = [70]
+# Array of duration values to test
+DURATION_SECS_VALUES = [60]  # Duration values in seconds
+# Array of item count values to test
+ITEM_COUNT_VALUES = [10]  # Number of items per request
+# Number of unique requests to generate (will be reused)
+NUM_UNIQUE_REQUESTS = 100
+DISTRIBUTION = "POISSON"  # Options: "CONSTANT", "POISSON"
+
+# Profiling Configuration
+PROFILE = False  # Enable profiling with START_PROFILE/STOP_PROFILE requests
+# Directory for profiler output
+SGLANG_TORCH_PROFILER_DIR = "/shared/user/sglang-oss-trace/remove-decode"
+if PROFILE:
+    os.environ["SGLANG_TORCH_PROFILER_DIR"] = SGLANG_TORCH_PROFILER_DIR
+
+# Special token to replicate for precise token counting
+SPECIAL_REPLICATED_TOKEN = "<|im_start|>"
+
+
+###############################################################################
+# REQUEST GENERATION (in parallel)
+###############################################################################
+def prepare_all_requests_parallel(num_requests, item_count):
+    """
+    Generates unique requests in parallel, then reuses them to create the
+    full request list. Returns a list of score-request dicts for HTTP.
+    """
+    # Load tokenizer once here to verify special token and get precise counts
+    print("Loading tokenizer...")
+    tokenizer = AutoTokenizer.from_pretrained(SCORE_MODEL_PATH)
+
+    # Verify that our special token produces exactly 1 token
+    special_token_count = len(
+        tokenizer.encode(SPECIAL_REPLICATED_TOKEN, add_special_tokens=False)
+    )
+    print(
+        f"Special token '{SPECIAL_REPLICATED_TOKEN}' produces "
+        f"{special_token_count} token(s)"
+    )
+
+    def generate_text_with_token_count(num_toks):
+        """Generate text with a precise token count using the replicated token."""
+        if special_token_count == 1:
+            # Simple case: token maps to exactly 1 token
+            return SPECIAL_REPLICATED_TOKEN * num_toks
+        else:
+            print(
+                f"Special token '{SPECIAL_REPLICATED_TOKEN}' produces more than 1 token!!!"
+            )
+            # Handle case where special token produces multiple tokens
+            # Repeat the token enough times to get at least num_toks tokens
+            repetitions = (num_toks + special_token_count - 1) // special_token_count
+            text = SPECIAL_REPLICATED_TOKEN * repetitions
+
+            # Verify we got the expected token count (approximately)
+            actual_tokens = len(tokenizer.encode(text, add_special_tokens=False))
+            if actual_tokens < num_toks:
+                print(
+                    f"Warning: Generated {actual_tokens} tokens, "
+                    f"expected {num_toks}"
+                )
+
+            return text
+
+    def build_request(index):
+        """Build a single request using the shared tokenizer."""
+        try:
+            # Generate query and items for the score API
+            query = generate_text_with_token_count(SCORE_QUERY_TOKENS)
+            items = [
+                generate_text_with_token_count(SCORE_ITEM_TOKENS)
+                for _ in range(item_count)
+            ]
+
+            # Return as dict in score API format
+            score_data = {
+                "query": query,
+                "items": items,
+                "label_token_ids": SCORE_LABEL_TOKEN_IDS,
+                "model": SCORE_MODEL_PATH,
+            }
+            return (index, score_data)
+
+        except Exception as e:
+            print(f"Error building request {index}: {e}")
+            return (index, None)
+
+    # Generate only the unique requests
+    unique_requests = [None] * NUM_UNIQUE_REQUESTS
+
+    # Use ThreadPoolExecutor instead of ProcessPoolExecutor to avoid
+    # tokenizer loading issues across processes
+    max_workers = min(8, os.cpu_count() or 1)  # Limit to 8 threads max
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = []
+        for i in tqdm(
+            range(NUM_UNIQUE_REQUESTS), desc="Submitting prompt generation tasks"
+        ):
+            future = executor.submit(build_request, i)
+            futures.append(future)
+
+        # Collect results as they complete
+        for f in tqdm(
+            concurrent.futures.as_completed(futures),
+            desc="Building unique requests",
+            total=NUM_UNIQUE_REQUESTS,
+        ):
+            try:
+                index, req_data = f.result()
+                if req_data is not None:
+                    unique_requests[index] = req_data
+                else:
+                    print(f"Failed to build request {index}")
+            except Exception as e:
+                print(f"Error processing request result: {e}")
+
+    # Check if we have any valid requests
+    valid_requests = [req for req in unique_requests if req is not None]
+    if not valid_requests:
+        raise RuntimeError("Failed to generate any valid requests")
+
+    print(
+        f"Successfully generated {len(valid_requests)} out of "
+        f"{NUM_UNIQUE_REQUESTS} unique requests"
+    )
+
+    # Create the full request list by cycling through unique requests
+    print(
+        f"Reusing {len(valid_requests)} unique requests to create "
+        f"{num_requests} total requests..."
+    )
+    all_requests = []
+    for i in tqdm(range(num_requests), desc="Reusing requests"):
+        unique_index = i % len(valid_requests)
+        all_requests.append(valid_requests[unique_index])
+
+    print("All prompts/requests prepared.\n")
+    return all_requests
+
+
+###############################################################################
+# PROFILING HELPERS
+###############################################################################
+async def send_profile_request(profile_text, item_count, session=None):
+    """Send a profile request and wait for completion."""
+    try:
+        if session:
+            print(f"Sending {profile_text} request via HTTP...")
+
+            # Determine the correct endpoint
+            base_url = HTTP_URL.rsplit("/", 2)[0]  # Remove /v1/score
+            if profile_text == "START_PROFILE":
+                endpoint_url = f"{base_url}/start_profile"
+            elif profile_text == "STOP_PROFILE":
+                endpoint_url = f"{base_url}/stop_profile"
+            else:
+                print(f"Unknown profile request: {profile_text}")
+                return
+
+            headers = {"Content-Type": "application/json"}
+
+            async with session.post(endpoint_url, headers=headers) as resp:
+                resp_text = await resp.text()
+                if resp.status == 200:
+                    print(f"{profile_text} request completed")
+                else:
+                    print(
+                        f"{profile_text} request failed with status "
+                        f"{resp.status}: {resp_text}"
+                    )
+        else:
+            print(f"Cannot send {profile_text} request - missing session")
+
+    except Exception as e:
+        print(f"Error sending {profile_text} request: {e}")
+
+
+###############################################################################
+# HTTP CALLS
+###############################################################################
+def build_http_request_json(score_data):
+    """Build HTTP request JSON for the /v1/score endpoint.
+
+    Score API format:
+    {
+        "query": "Generated query text with SCORE_QUERY_TOKENS tokens",
+        "items": ["item1", "item2", ...],  # Items to score, SCORE_ITEM_TOKENS each
+        "label_token_ids": [token_id1, token_id2],  # Target token IDs
+        "model": "/path/to/model"
+    }
+
+    Args:
+        score_data: A dict containing query, items, label_token_ids, and model
+    """
+    # score_data is already in the correct format from build_request
+    return json.dumps(score_data)
+
+
+async def make_http_call(session, score_data, request_id, results_queue):
+    """HTTP call to the /v1/score endpoint."""
+    try:
+        start_time = asyncio.get_event_loop().time()
+
+        request_json = build_http_request_json(score_data)
+        headers = {"Content-Type": "application/json"}
+
+        async with session.post(HTTP_URL, data=request_json, headers=headers) as resp:
+            resp_text = await resp.text()
+
+            if resp.status != 200:
+                print(
+                    f"[HTTP] Request {request_id} failed with status "
+                    f"{resp.status}: {resp_text}"
+                )
+                completion_time = asyncio.get_event_loop().time()
+                await results_queue.put((request_id, 0, False, completion_time))
+                return
+
+            # Parse the score API response
+            try:
+                response_data = json.loads(resp_text)
+                # Score API returns scores for each item
+                # For now, just verify we got a valid response
+                if "scores" in response_data or "logprobs" in response_data:
+                    success = True
+                else:
+                    print(
+                        f"[HTTP] Request {request_id} missing expected fields in response"
+                    )
+                    success = False
+            except json.JSONDecodeError:
+                print(f"[HTTP] Request {request_id} failed to parse JSON response")
+                success = False
+
+            completion_time = asyncio.get_event_loop().time()
+            elapsed_time = (completion_time - start_time) * 1000
+            await results_queue.put((request_id, elapsed_time, success, completion_time))
+
+    except Exception as e:
+        print(f"[HTTP] Error for request {request_id}: {e}")
+        completion_time = asyncio.get_event_loop().time()
+        await results_queue.put((request_id, 0, False, completion_time))
+
+
+###############################################################################
+# RESULTS
+###############################################################################
+async def process_results(
+    results_queue,
+    num_requests,
+    send_duration,
+    total_duration,
+    rps,
+    duration_secs,
+    item_count,
+    test_start_time,
+):
+    """Processes results and groups them by minute intervals.
+    Returns a list of dictionaries, one for each minute."""
+    all_results = []
+
+    # Collect all results
+    for _ in range(num_requests):
+        result = await results_queue.get()
+        request_id, elapsed_time, success, completion_time = result
+        all_results.append(
+            {
+                "request_id": request_id,
+                "elapsed_time": elapsed_time,
+                "success": success,
+                "completion_time": completion_time,
+            }
+        )
+
+    # Group results by minute intervals
+    minute_results = []
+    num_minutes = int(duration_secs // 60) + (1 if duration_secs % 60 > 0 else 0)
+
+    for minute in range(num_minutes):
+        minute_start = test_start_time + (minute * 60)
+        minute_end = test_start_time + ((minute + 1) * 60)
+
+        # Filter results that completed in this minute
+        minute_data = [
+            r for r in all_results if minute_start <= r["completion_time"] < minute_end
+        ]
+
+        response_times = [r["elapsed_time"] for r in minute_data if r["success"]]
+        successful_requests = len([r for r in minute_data if r["success"]])
+        failed_requests = len([r for r in minute_data if not r["success"]])
+
+        avg_response_time = mean(response_times) if response_times else 0
+
+        # Calculate percentiles using numpy
+        if response_times:
+            p50 = np.percentile(response_times, 50)
+            p90 = np.percentile(response_times, 90)
+            p99 = np.percentile(response_times, 99)
+        else:
+            p50 = p90 = p99 = 0
+
+        minute_result = {
+            "test_duration_secs": duration_secs,
+            "minute_interval": minute + 1,
+            "target_rps": rps,
+            "item_count": item_count,
+            "server_type": SERVER_TYPE,
+            "distribution": DISTRIBUTION,
+            "unique_requests": NUM_UNIQUE_REQUESTS,
+            "total_requests": len(minute_data),
+            "successful_requests": successful_requests,
+            "failed_requests": failed_requests,
+            "send_duration_secs": send_duration,
+            "total_duration_secs": total_duration,
+            "avg_response_time_ms": avg_response_time,
+            "p50_response_time_ms": p50,
+            "p90_response_time_ms": p90,
+            "p99_response_time_ms": p99,
+        }
+
+        minute_results.append(minute_result)
+
+        print(
+            f"\nMinute {minute + 1} Summary for RPS {rps}, "
+            f"Duration {duration_secs}s, Item Count {item_count}:"
+        )
+        print(f"  Requests completed in minute: {len(minute_data)}")
+        print(f"  Successful requests: {successful_requests}")
+        print(f"  Failed requests: {failed_requests}")
+        print(f"  Average response time: {avg_response_time:.2f} ms")
+        print(f"  P50 response time: {p50:.2f} ms")
+        print(f"  P90 response time: {p90:.2f} ms")
+        print(f"  P99 response time: {p99:.2f} ms")
+
+    # Also print overall summary
+    all_response_times = [r["elapsed_time"] for r in all_results if r["success"]]
+    total_successful = len([r for r in all_results if r["success"]])
+    total_failed = len([r for r in all_results if not r["success"]])
+
+    overall_avg = mean(all_response_times) if all_response_times else 0
+    if all_response_times:
+        overall_p50 = np.percentile(all_response_times, 50)
+        overall_p90 = np.percentile(all_response_times, 90)
+        overall_p99 = np.percentile(all_response_times, 99)
+    else:
+        overall_p50 = overall_p90 = overall_p99 = 0
+
+    print(
f"\nOverall Summary for RPS {rps}, Duration {duration_secs}s, " + f"Item Count {item_count}:" + ) + print(f" Test duration: {duration_secs} seconds") + print(f" Server type: {SERVER_TYPE}") + print(f" HTTP mode: SINGLE_ITEM_SCORING") + print(f" Target RPS: {rps}") + print(f" Item count: {item_count}") + print(f" Distribution: {DISTRIBUTION}") + print(f" Unique requests generated: {NUM_UNIQUE_REQUESTS}") + print(f" Total requests sent: {num_requests}") + print(f" Successful requests: {total_successful}") + print(f" Failed requests: {total_failed}") + print(f" Time to send all requests: {send_duration:.2f} seconds") + print(f" Time for all requests to complete: {total_duration:.2f} seconds") + print(f" Average response time: {overall_avg:.2f} ms") + print(f" P50 response time: {overall_p50:.2f} ms") + print(f" P90 response time: {overall_p90:.2f} ms") + print(f" P99 response time: {overall_p99:.2f} ms\n") + + return minute_results + + +############################################################################### +# MAIN +############################################################################### +async def run_benchmark(rps, duration_secs, item_count): + """Run a single benchmark with the given RPS value.""" + num_requests = int(rps * duration_secs) + print( + f"Starting benchmark with RPS={rps}, Duration={duration_secs}s, " + f"Item Count={item_count}, num_requests={num_requests}" + ) + print(f"Server Type: {SERVER_TYPE}") + print(f"HTTP Mode: SINGLE_ITEM_SCORING") + print(f"Profiling Enabled: {PROFILE}") + + # Build requests in parallel (unmeasured) + all_requests = prepare_all_requests_parallel(num_requests, item_count) + + results_queue = asyncio.Queue() + tasks = [] + + # Track timing for sending requests + send_start_time = asyncio.get_event_loop().time() + + # HTTP implementation (open source only supports HTTP with /v1/score API) + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=300) + ) as session: + + # Send START_PROFILE if profiling is enabled + if PROFILE: + await send_profile_request("START_PROFILE", item_count, session=session) + + # Add progress bar for sending requests + with tqdm( + total=len(all_requests), + desc=f"Sending HTTP score requests at {rps} RPS", + unit="req", + ) as pbar: + for i, score_data in enumerate(all_requests): + request_id = i + 1 + tasks.append( + asyncio.create_task( + make_http_call(session, score_data, request_id, results_queue) + ) + ) + + # Update progress bar + pbar.update(1) + + # Throttle based on distribution + if i < len(all_requests) - 1: + if DISTRIBUTION == "CONSTANT": + interval = 1 / rps + await asyncio.sleep(interval) + elif DISTRIBUTION == "POISSON": + # For Poisson process, inter-arrival times follow + # exponential distribution + interval = random.expovariate(rps) + await asyncio.sleep(interval) + else: + raise ValueError( + f"Unknown distribution: {DISTRIBUTION}. " + f"Use 'CONSTANT' or 'POISSON'." 
+                        )
+
+        send_end_time = asyncio.get_event_loop().time()
+        send_duration = send_end_time - send_start_time
+
+        # Wait for all requests to complete with progress tracking
+        print(f"Waiting for {len(tasks)} HTTP score requests to complete...")
+        with tqdm(
+            total=len(tasks), desc="Completing HTTP score requests", unit="req"
+        ) as completion_pbar:
+            completed_tasks = []
+            for task in asyncio.as_completed(tasks):
+                await task
+                completed_tasks.append(task)
+                completion_pbar.update(1)
+
+        # Send STOP_PROFILE if profiling is enabled
+        if PROFILE:
+            await send_profile_request("STOP_PROFILE", item_count, session=session)
+
+    completion_end_time = asyncio.get_event_loop().time()
+    total_duration = completion_end_time - send_start_time
+
+    return await process_results(
+        results_queue,
+        num_requests,
+        send_duration,
+        total_duration,
+        rps,
+        duration_secs,
+        item_count,
+        send_start_time,
+    )
+
+
+async def main():
+    """Main function that runs benchmarks for all RPS values."""
+    total_combinations = (
+        len(DURATION_SECS_VALUES) * len(RPS_VALUES) * len(ITEM_COUNT_VALUES)
+    )
+    print(
+        f"Running benchmarks for {len(DURATION_SECS_VALUES)} duration "
+        f"values, {len(RPS_VALUES)} RPS values, and "
+        f"{len(ITEM_COUNT_VALUES)} item count values = "
+        f"{total_combinations} total combinations"
+    )
+    print(f"Server Type: {SERVER_TYPE}")
+    print("HTTP Mode: SINGLE_ITEM_SCORING")
+    print(f"Score API URL: {HTTP_URL}")
+    print(f"Query tokens per request: {SCORE_QUERY_TOKENS}")
+    print(f"Item tokens per item: {SCORE_ITEM_TOKENS}")
+    print(f"Items per request (batch size): {ITEM_COUNT_VALUES}")
+    print(f"Profiling Enabled: {PROFILE}")
+    print(f"Duration values: {DURATION_SECS_VALUES}")
+    print(f"RPS values: {RPS_VALUES}")
+    print(f"Item count values: {ITEM_COUNT_VALUES}")
+    print("=" * 80)
+
+    all_results = []
+
+    for duration_secs in DURATION_SECS_VALUES:
+        for rps in RPS_VALUES:
+            for item_count in ITEM_COUNT_VALUES:
+                result = await run_benchmark(rps, duration_secs, item_count)
+                all_results.extend(result)  # Extend with minute results
+
+    # Print CSV header and results
+    print("\n" + "=" * 80)
+    print("FINAL CSV RESULTS:")
+    print("=" * 80)
+
+    # CSV Header
+    headers = [
+        "test_duration_secs",
+        "minute_interval",
+        "target_rps",
+        "item_count",
+        "server_type",
+        "distribution",
+        "unique_requests",
+        "total_requests",
+        "successful_requests",
+        "failed_requests",
+        "send_duration_secs",
+        "total_duration_secs",
+        "avg_response_time_ms",
+        "p50_response_time_ms",
+        "p90_response_time_ms",
+        "p99_response_time_ms",
+    ]
+    print(",".join(headers))
+
+    # CSV Data
+    for result in all_results:
+        row = [
+            result["test_duration_secs"],
+            result["minute_interval"],
+            result["target_rps"],
+            result["item_count"],
+            result["server_type"],
+            result["distribution"],
+            result["unique_requests"],
+            result["total_requests"],
+            result["successful_requests"],
+            result["failed_requests"],
+            f"{result['send_duration_secs']:.2f}",
+            f"{result['total_duration_secs']:.2f}",
+            f"{result['avg_response_time_ms']:.2f}",
+            f"{result['p50_response_time_ms']:.2f}",
+            f"{result['p90_response_time_ms']:.2f}",
+            f"{result['p99_response_time_ms']:.2f}",
+        ]
+        print(",".join(map(str, row)))
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
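For reference, the request shape the benchmark sends can be reproduced in a few lines. This is a minimal sketch, assuming a server is already listening on `localhost:30000`; the query/item strings are illustrative, and the token IDs are the Yes/No IDs from the CONFIG block above.

```python
# Minimal sketch of one /v1/score call as bench_score.py issues it.
# Assumes an SGLang server on localhost:30000; strings are illustrative.
import requests

payload = {
    "query": "Is this review positive? ",      # shared query text
    "items": ["Great product!", "Terrible."],  # one score set per item
    "label_token_ids": [9454, 2753],           # Yes/No token IDs from CONFIG
    "model": "Qwen/Qwen3-0.6B",
}
resp = requests.post("http://localhost:30000/v1/score", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["scores"])  # the benchmark accepts "scores" or "logprobs"
```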
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 7628ec2dd..55b1a9ec7 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -913,6 +913,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     # hicache pointer for synchronizing data loading from CPU to GPU
     hicache_consumer_index: int = 0
 
@@ -953,6 +956,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
+            is_prefill_only=all(
+                req.sampling_params.max_new_tokens == 0 for req in reqs
+            ),
             chunked_req=chunked_req,
         )
 
@@ -1796,6 +1802,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             is_extend_in_batch=self.is_extend_in_batch,
+            is_prefill_only=self.is_prefill_only,
         )
 
     def _evict_tree_cache_if_needed(self, num_tokens: int):
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index fc0055b2b..95a529c89 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1466,8 +1466,9 @@ class Scheduler(
         if self.last_batch.batch_size() < last_bs:
             self.running_batch.batch_is_full = False
 
-        # Merge the new batch into the running batch
-        if not self.last_batch.is_empty():
+        # Merge the new batch into the running batch.
+        # For a prefill-only batch, we can skip the decoding step entirely.
+        if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
             if self.running_batch.is_empty():
                 self.running_batch = self.last_batch
             else:
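The `is_prefill_only` flag is what lets the scheduler drop a finished score batch instead of merging it into the running (decode) batch. A toy restatement of the classification rule follows, using hypothetical stand-in classes rather than SGLang's real request types:

```python
# Toy restatement of the rule added in schedule_batch.py: a batch is
# prefill-only iff every request asks for zero new tokens. The classes are
# hypothetical stand-ins, not SGLang types.
from dataclasses import dataclass


@dataclass
class SamplingParams:
    max_new_tokens: int


@dataclass
class Req:
    sampling_params: SamplingParams


def is_prefill_only(reqs):
    return all(r.sampling_params.max_new_tokens == 0 for r in reqs)


score_batch = [Req(SamplingParams(0)), Req(SamplingParams(0))]
mixed_batch = [Req(SamplingParams(0)), Req(SamplingParams(16))]
assert is_prefill_only(score_batch)      # skipped by the decode merge
assert not is_prefill_only(mixed_batch)  # still enters the decode loop
```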
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 00dd6a065..2b9aa8219 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -699,7 +699,7 @@ class TokenizerManager:
         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
-            self._validate_token_len(obj[i], input_ids_list[i])
+            self._validate_one_request(obj[i], input_ids_list[i])
             tokenized_objs.append(
                 self._create_tokenized_object(
                     req, req.text, input_ids_list[i], None, None
@@ -1892,6 +1892,13 @@ class TokenizerManager:
                     f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                 )
 
+        batch_request = GenerateReqInput(
+            token_ids_logprob=label_token_ids,
+            return_logprob=True,
+            stream=False,
+            sampling_params={"max_new_tokens": 0},
+        )
+
         # Handle string or tokenized query/items
         if isinstance(query, str) and (
             isinstance(items, str)
             or (isinstance(items, list) and all(isinstance(i, str) for i in items))
         ):
             items_list = [items] if isinstance(items, str) else items
             if item_first:
                 prompts = [f"{item}{query}" for item in items_list]
             else:
                 prompts = [f"{query}{item}" for item in items_list]
-            batch_request = GenerateReqInput(
-                text=prompts,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.text = prompts
+
         elif (
             isinstance(query, list)
             and isinstance(items, list)
             and all(isinstance(i, list) for i in items)
         ):
             if item_first:
                 input_ids_list = [item + query for item in items]
             else:
                 input_ids_list = [query + item for item in items]
-            batch_request = GenerateReqInput(
-                input_ids=input_ids_list,
-                return_logprob=True,
-                token_ids_logprob=label_token_ids,
-                stream=False,
-                sampling_params={"max_new_tokens": 1},
-            )
+
+            batch_request.input_ids = input_ids_list
         else:
             raise ValueError(
                 "Invalid combination of query/items types for score_request."
@@ -1939,9 +1937,20 @@ class TokenizerManager:
         for result in results:
             # Get logprobs for each token
             logprobs = {}
-            for logprob, token_id, _ in result["meta_info"].get(
-                "output_token_ids_logprobs", []
-            )[0]:
+
+            # For scoring requests, read from output_token_ids_logprobs: we
+            # want the logprobs of the label_token_ids at the first position
+            # after the last token in the prompt.
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            # Raise an error if the backend returned no output logprobs
+            if not output_logprobs:
+                raise RuntimeError(
+                    f"output_logprobs is missing or empty for request {result['meta_info'].get('id', '')}. "
+                    "This usually indicates a problem with the scoring request or the backend output."
+                )
+
+            for logprob, token_id, _ in output_logprobs[0]:
                 if token_id in label_token_ids:
                     logprobs[token_id] = logprob
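To make the hunk above concrete, here is the same label-logprob extraction run against a fabricated result dict, followed by the softmax step that `apply_softmax=True` requests. All numbers and the request id are made up for illustration.

```python
# Fabricated example of the post-processing in score_request: read the
# label-token logprobs at the first output position, then softmax them.
import math

label_token_ids = [9454, 2753]
result = {
    "meta_info": {
        "id": "req-0",
        # One list per output position; entries are (logprob, token_id, text).
        "output_token_ids_logprobs": [[(-0.3, 9454, None), (-1.5, 2753, None)]],
    }
}

output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
if not output_logprobs:
    raise RuntimeError("output_logprobs is missing or empty for request "
                       f"{result['meta_info'].get('id', '')}")

logprobs = {tid: lp for lp, tid, _ in output_logprobs[0] if tid in label_token_ids}

# apply_softmax=True normalizes the label logprobs into a distribution
exps = {tid: math.exp(lp) for tid, lp in logprobs.items()}
total = sum(exps.values())
scores = {tid: e / total for tid, e in exps.items()}
print(scores)  # ~{9454: 0.77, 2753: 0.23}
```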
diff --git a/test/srt/test_score_api.py b/test/srt/test_score_api.py
index afd7d00f4..d08ae9df7 100644
--- a/test/srt/test_score_api.py
+++ b/test/srt/test_score_api.py
@@ -213,6 +213,88 @@ class TestScoreAPI(CustomTestCase):
             1.0, sum(score_list), 6, "Scores should sum to 1"
         )
 
+    def test_score_request_construction(self):
+        """Test that scoring requests are constructed to avoid the decode phase."""
+        from unittest.mock import patch
+
+        # Capture the internal request to verify the optimization
+        captured_requests = []
+        original_gen = self.engine.tokenizer_manager.generate_request
+
+        async def mock_generate_request(req, request=None):
+            captured_requests.append(req)
+            async for result in original_gen(req, request):
+                yield result
+
+        # Patch the generate_request method
+        with patch.object(
+            self.engine.tokenizer_manager,
+            "generate_request",
+            side_effect=mock_generate_request,
+        ):
+            # Run a scoring request
+            query = "What is the capital of"
+            items = ["France", "Germany"]
+            label_token_ids = [1, 2, 3]
+
+            scores = self.engine.score(
+                query=query,
+                items=items,
+                label_token_ids=label_token_ids,
+                apply_softmax=True,
+            )
+
+            # Verify we got results
+            self.assertEqual(len(scores), len(items))
+
+            # Verify the captured request has decode-avoiding properties
+            self.assertEqual(len(captured_requests), 1)
+            request = captured_requests[0]
+
+            # Key assertions for decode-phase avoidance:
+            # 1. max_new_tokens should be 0 (prevents token generation)
+            # Handle both single and batch request cases
+            if isinstance(request.sampling_params, dict):
+                max_new_tokens = request.sampling_params.get("max_new_tokens", 0)
+            elif isinstance(request.sampling_params, list):
+                # For batch requests, check the first item
+                max_new_tokens = request.sampling_params[0].get("max_new_tokens", 0)
+            else:
+                max_new_tokens = getattr(request.sampling_params, "max_new_tokens", 0)
+
+            self.assertEqual(
+                max_new_tokens, 0, "max_new_tokens should be 0 to avoid decode phase"
+            )
+
+            # 2. Should have token_ids_logprob for scoring
+            # Handle both single and batch request cases
+            if (
+                isinstance(request.token_ids_logprob, list)
+                and len(request.token_ids_logprob) > 0
+                and isinstance(request.token_ids_logprob[0], list)
+            ):
+                # Batch case: token_ids_logprob is a list of lists
+                # Each item in the batch should have the same label_token_ids
+                for item_token_ids in request.token_ids_logprob:
+                    self.assertEqual(
+                        item_token_ids,
+                        label_token_ids,
+                        "Each batch item should have label_token_ids for scoring",
+                    )
+            else:
+                # Single request case
+                self.assertEqual(
+                    request.token_ids_logprob,
+                    label_token_ids,
+                    "Should have label_token_ids for scoring",
+                )
+
+            # 3. Should request logprobs but not stream
+            self.assertTrue(
+                request.return_logprob, "Should request logprobs for scoring"
+            )
+            self.assertFalse(request.stream, "Scoring requests should not stream")
+
 
 if __name__ == "__main__":
     unittest.main()
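The new unit-test file below covers the separate `enable_tokenizer_batch_encode` path. Mirroring its fixtures, the feature is switched on through `ServerArgs`; a minimal sketch (the model path is a placeholder):

```python
# Batch tokenization is opt-in via ServerArgs, exactly as the test fixtures
# below construct it. The model path here is a placeholder.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="Qwen/Qwen3-0.6B",
    enable_tokenizer_batch_encode=True,
)
assert args.enable_tokenizer_batch_encode
```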
diff --git a/test/srt/test_tokenizer_batch_encode.py b/test/srt/test_tokenizer_batch_encode.py
new file mode 100644
index 000000000..f3294c049
--- /dev/null
+++ b/test/srt/test_tokenizer_batch_encode.py
@@ -0,0 +1,121 @@
+"""
+Unit tests for the enable_tokenizer_batch_encode feature.
+
+This tests the batch tokenization functionality, which allows processing
+multiple text inputs in a single batch for improved performance.
+
+Usage:
+python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_batch_encode_enabled
+python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_multimodal_input_validation
+python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_valid_text_only_requests_pass_validation
+"""
+
+import asyncio
+import unittest
+from typing import List
+from unittest.mock import AsyncMock, Mock, call, patch
+
+from sglang.srt.managers.io_struct import GenerateReqInput, TokenizedGenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+
+class TestTokenizerBatchEncode(unittest.TestCase):
+    """Test cases for tokenizer batch encoding validation and setup."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        self.server_args = ServerArgs(
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            enable_tokenizer_batch_encode=True,
+        )
+        self.port_args = PortArgs.init_new(self.server_args)
+
+        with patch("zmq.asyncio.Context"), patch(
+            "sglang.srt.utils.get_zmq_socket"
+        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
+
+            mock_tokenizer.return_value = Mock(vocab_size=32000)
+            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
+
+    def test_batch_encode_enabled(self):
+        """Test that batch encoding is enabled when configured."""
+        self.assertTrue(self.server_args.enable_tokenizer_batch_encode)
+
+    def test_batch_encode_disabled(self):
+        """Test that batch encoding can be disabled."""
+        server_args_disabled = ServerArgs(
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            enable_tokenizer_batch_encode=False,
+        )
+        self.assertFalse(server_args_disabled.enable_tokenizer_batch_encode)
+
+    def test_multimodal_input_validation(self):
+        """Test that multimodal inputs are rejected in batch mode."""
+        req = GenerateReqInput(text="test", image_data=["dummy"])
+        req.contains_mm_input = Mock(return_value=True)
+
+        batch_obj = Mock()
+        batch_obj.__getitem__ = lambda self, i: req
+
+        self.tokenizer_manager.is_generation = True
+
+        with self.assertRaises(ValueError) as cm:
+            self.tokenizer_manager._validate_batch_tokenization_constraints(
+                1, batch_obj
+            )
+
+        self.assertIn("multimodal", str(cm.exception))
+
+    def test_pretokenized_input_validation(self):
+        """Test that pre-tokenized inputs are rejected in batch mode."""
+        req = GenerateReqInput(input_ids=[1, 2, 3])
+
+        batch_obj = Mock()
+        batch_obj.__getitem__ = lambda self, i: req
+
+        with self.assertRaises(ValueError) as cm:
+            self.tokenizer_manager._validate_batch_tokenization_constraints(
+                1, batch_obj
+            )
+
+        self.assertIn("pre-tokenized", str(cm.exception))
+
+    def test_input_embeds_validation(self):
+        """Test that input embeds are rejected in batch mode."""
+        req = GenerateReqInput(input_embeds=[0.1, 0.2])
+
+        batch_obj = Mock()
+        batch_obj.__getitem__ = lambda self, i: req
+
+        with self.assertRaises(ValueError) as cm:
+            self.tokenizer_manager._validate_batch_tokenization_constraints(
+                1, batch_obj
+            )
+
+        self.assertIn("input_embeds", str(cm.exception))
+
+    def test_valid_text_only_requests_pass_validation(self):
+        """Test that valid text-only requests pass validation."""
+        # Create valid requests (text-only)
+        requests = []
+        for i in range(3):
+            req = GenerateReqInput(text=f"test text {i}")
+            req.contains_mm_input = Mock(return_value=False)
+            requests.append(req)
+
+        batch_obj = Mock()
+        batch_obj.__getitem__ = Mock(side_effect=lambda i: requests[i])
+
+        # Should not raise any exception
+        try:
+            self.tokenizer_manager._validate_batch_tokenization_constraints(
+                3, batch_obj
+            )
+        except Exception as e:
+            self.fail(f"Validation failed for valid text-only requests: {e}")
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
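Finally, a short end-to-end sketch of the offline `Engine.score(...)` path that `test_score_api.py` exercises. The model path and label token IDs are placeholders; with `apply_softmax=True`, each item's scores over the label tokens sum to 1.

```python
# End-to-end sketch of offline scoring with the Engine API used in the tests.
# Model path and label token IDs are placeholders.
import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen3-0.6B")
scores = engine.score(
    query="What is the capital of",
    items=["France", "Germany"],
    label_token_ids=[9454, 2753],
    apply_softmax=True,  # per-item scores over the labels sum to 1
)
print(scores)  # one list of per-label scores per item
engine.shutdown()
```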