[Generative Score API] Optimization to Remove Decode. (#8840)
This commit is contained in:
committed by
GitHub
parent
9e426466af
commit
a027a9b4b3
603
benchmark/score/bench_score.py
Normal file
603
benchmark/score/bench_score.py
Normal file
@@ -0,0 +1,603 @@
|
|||||||
|
"""
|
||||||
|
SGLang Scoring Benchmark Script
|
||||||
|
|
||||||
|
This script benchmarks SGLang's scoring API performance using HTTP requests.
|
||||||
|
|
||||||
|
Current Features:
|
||||||
|
- HTTP-only implementation (open source compatible)
|
||||||
|
- Uses /v1/score API endpoint directly
|
||||||
|
- Single item scoring with batching support
|
||||||
|
- Configurable RPS, duration, and batch sizes
|
||||||
|
- Progress tracking and detailed metrics
|
||||||
|
- Poisson and constant request distributions
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- Update configuration variables at the top of the file
|
||||||
|
- Ensure SGLang server is running on the configured HTTP_URL
|
||||||
|
- Run: python bench_score.py
|
||||||
|
- Each request will contain ITEM_COUNT_VALUES items for batch scoring
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures # For parallel prompt generation
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
###############################################################################
# CONFIG
###############################################################################
# Server Configuration
SERVER_TYPE = "HTTP"  # Fixed to HTTP for open source

# HTTP Configuration
HTTP_URL = "http://localhost:30000/v1/score"  # Use score API directly

# Score API Config
# ITEM_COUNT_VALUES determines number of items per score request (batch size)
SCORE_QUERY_TOKENS = 120  # Tokens generated for the query of every request
SCORE_ITEM_TOKENS = 180  # Tokens generated for each item of every request
SCORE_MODEL_PATH = "Qwen/Qwen3-0.6B"
SCORE_LABEL_TOKEN_IDS = [9454, 2753]  # Yes/No token IDs

# Array of RPS values to test
RPS_VALUES = [70]
# Array of duration values to test
DURATION_SECS_VALUES = [60]  # Duration values in seconds
# Array of item count values to test
ITEM_COUNT_VALUES = [10]  # Number of items per request
# Number of unique requests to generate (will be reused)
NUM_UNIQUE_REQUESTS = 100
DISTRIBUTION = "POISSON"  # Options: "CONSTANT", "POISSON"

# Profiling Configuration
PROFILE = False  # Enable profiling with START_PROFILE/STOP_PROFILE prompts
# Directory for profiler output. A value already exported in the environment
# takes precedence, so the machine-specific default below is only a fallback
# and no longer clobbers the caller's setting.
SGLANG_TORCH_PROFILER_DIR = os.environ.get(
    "SGLANG_TORCH_PROFILER_DIR", "/shared/user/sglang-oss-trace/remove-decode"
)
if PROFILE:
    # NOTE(review): this only changes the environment of the benchmark
    # process; presumably the server must be launched with the same variable
    # for the trace to land in this directory - confirm.
    os.environ["SGLANG_TORCH_PROFILER_DIR"] = SGLANG_TORCH_PROFILER_DIR

# Special token to replicate for precise token counting
SPECIAL_REPLICATED_TOKEN = "<|im_start|>"
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# REQUEST GENERATION (in parallel)
|
||||||
|
###############################################################################
|
||||||
|
def prepare_all_requests_parallel(num_requests, item_count):
    """
    Build the score-request payloads for one benchmark run.

    NUM_UNIQUE_REQUESTS payloads are built on a thread pool and then cycled
    through until ``num_requests`` payloads have been collected.

    Args:
        num_requests: Total number of payloads to return.
        item_count: Number of items placed in every payload.

    Returns:
        A list of ``num_requests`` dicts with the keys ``query``, ``items``,
        ``label_token_ids`` and ``model``. Entries are references to the
        pooled payloads, not copies.

    Raises:
        RuntimeError: If not a single payload could be built.
    """
    # The tokenizer is loaded once; it is only used to measure how many
    # tokens the replicated marker string expands to.
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(SCORE_MODEL_PATH)

    marker_ids = tokenizer.encode(SPECIAL_REPLICATED_TOKEN, add_special_tokens=False)
    special_token_count = len(marker_ids)
    print(
        f"Special token '{SPECIAL_REPLICATED_TOKEN}' produces "
        f"{special_token_count} token(s)"
    )

    def generate_text_with_token_count(num_toks):
        """Return text made of the replicated marker covering num_toks tokens."""
        # Common case: one marker is exactly one token, so repeat it directly.
        if special_token_count == 1:
            return SPECIAL_REPLICATED_TOKEN * num_toks

        print(
            f"Special token '{SPECIAL_REPLICATED_TOKEN}' produces more than 1 token!!!"
        )
        # Ceiling division: enough repetitions to reach at least num_toks.
        repetitions = -(-num_toks // special_token_count)
        text = SPECIAL_REPLICATED_TOKEN * repetitions

        # Only a lower bound is checked; the text may exceed num_toks tokens.
        actual_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        if actual_tokens < num_toks:
            print(
                f"Warning: Generated {actual_tokens} tokens, "
                f"expected {num_toks}"
            )
        return text

    def build_request(index):
        """Return (index, payload); payload is None when building failed."""
        try:
            payload = {
                "query": generate_text_with_token_count(SCORE_QUERY_TOKENS),
                "items": [
                    generate_text_with_token_count(SCORE_ITEM_TOKENS)
                    for _ in range(item_count)
                ],
                "label_token_ids": SCORE_LABEL_TOKEN_IDS,
                "model": SCORE_MODEL_PATH,
            }
            return (index, payload)
        except Exception as e:
            print(f"Error building request {index}: {e}")
            return (index, None)

    # Slots are filled by index so the pool keeps submission order even
    # though futures complete in arbitrary order.
    unique_requests = [None] * NUM_UNIQUE_REQUESTS

    # Threads rather than processes, to avoid tokenizer loading issues
    # across processes; capped at 8 workers.
    worker_count = min(8, os.cpu_count() or 1)

    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
        pending = [
            executor.submit(build_request, slot)
            for slot in tqdm(
                range(NUM_UNIQUE_REQUESTS), desc="Submitting prompt generation tasks"
            )
        ]

        for done in tqdm(
            concurrent.futures.as_completed(pending),
            desc="Building unique requests",
            total=NUM_UNIQUE_REQUESTS,
        ):
            try:
                index, req_data = done.result()
                if req_data is None:
                    print(f"Failed to build request {index}")
                else:
                    unique_requests[index] = req_data
            except Exception as e:
                print(f"Error processing request result: {e}")

    # Drop the slots whose build failed.
    valid_requests = [entry for entry in unique_requests if entry is not None]
    if not valid_requests:
        raise RuntimeError("Failed to generate any valid requests")

    print(
        f"Successfully generated {len(valid_requests)} out of "
        f"{NUM_UNIQUE_REQUESTS} unique requests"
    )

    # Cycle through the pool until the requested total is reached.
    print(
        f"Reusing {len(valid_requests)} unique requests to create "
        f"{num_requests} total requests..."
    )
    pool_size = len(valid_requests)
    all_requests = [
        valid_requests[position % pool_size]
        for position in tqdm(range(num_requests), desc="Reusing requests")
    ]

    print("All prompts/requests prepared.\n")
    return all_requests
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# PROFILING HELPERS
|
||||||
|
###############################################################################
|
||||||
|
async def send_profile_request(profile_text, item_count, session=None):
    """
    POST a profiler control request to the server and report the outcome.

    Args:
        profile_text: ``"START_PROFILE"`` or ``"STOP_PROFILE"``; any other
            value is reported as unknown and nothing is sent.
        item_count: Accepted for signature compatibility; not used here.
        session: HTTP client session used for the POST. Without one the
            request is skipped with a message.

    Returns:
        None. Failures are printed, never raised.
    """
    try:
        if not session:
            print(f"Cannot send {profile_text} request - missing session")
            return

        print(f"Sending {profile_text} request via HTTP...")

        # Drop the trailing two path segments ("/v1/score") from the score
        # URL to obtain the server root.
        base_url = HTTP_URL.rsplit("/", 2)[0]
        if profile_text == "START_PROFILE":
            route = "start_profile"
        elif profile_text == "STOP_PROFILE":
            route = "stop_profile"
        else:
            print(f"Unknown profile request: {profile_text}")
            return
        endpoint_url = f"{base_url}/{route}"

        headers = {"Content-Type": "application/json"}

        async with session.post(endpoint_url, headers=headers) as resp:
            resp_text = await resp.text()
            if resp.status != 200:
                print(
                    f"{profile_text} request failed with status "
                    f"{resp.status}: {resp_text}"
                )
            else:
                print(f"{profile_text} request completed")

    except Exception as e:
        print(f"Error sending {profile_text} request: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# HTTP CALLS
|
||||||
|
###############################################################################
|
||||||
|
def build_http_request_json(score_data):
    """Serialize a score payload into the JSON body for /v1/score.

    Expected payload layout:
    {
        "query": "Generated query text with SCORE_QUERY_TOKENS tokens",
        "items": ["item1", "item2", ...],  # SCORE_ITEM_TOKENS tokens each
        "label_token_ids": [token_id1, token_id2],  # Target token IDs
        "model": "/path/to/model"
    }

    Args:
        score_data: Dict holding query, items, label_token_ids and model.

    Returns:
        The payload encoded as a JSON string.
    """
    # The payload already has the wire format, so it is encoded unchanged.
    encoded_body = json.dumps(score_data)
    return encoded_body
|
||||||
|
|
||||||
|
|
||||||
|
async def make_http_call(session, score_data, request_id, results_queue):
    """
    Send one score request and push its outcome onto ``results_queue``.

    Exactly one tuple ``(request_id, elapsed_ms, success, completion_time)``
    is queued per call. ``elapsed_ms`` is 0 for HTTP errors and exceptions;
    ``completion_time`` is read from the running event loop's clock.

    Args:
        session: HTTP client session used for the POST.
        score_data: Payload dict, serialized by ``build_http_request_json``.
        request_id: Identifier echoed back in the result tuple and in logs.
        results_queue: Queue receiving the result tuple.

    Returns:
        None. Errors are reported through the queued tuple, never raised.
    """
    # Inside a coroutine the running loop is the one whose clock we want.
    loop = asyncio.get_running_loop()
    try:
        start_time = loop.time()

        request_json = build_http_request_json(score_data)
        headers = {"Content-Type": "application/json"}

        async with session.post(HTTP_URL, data=request_json, headers=headers) as resp:
            resp_text = await resp.text()

            if resp.status != 200:
                print(
                    f"[HTTP] Request {request_id} failed with status "
                    f"{resp.status}: {resp_text}"
                )
                completion_time = loop.time()
                await results_queue.put((request_id, 0, False, completion_time))
                return

            # Parse score API response
            try:
                response_data = json.loads(resp_text)
            except json.JSONDecodeError:
                print(f"[HTTP] Request {request_id} failed to parse JSON response")
                success = False
            else:
                # Only a JSON object can carry the expected fields. Checking
                # the type keeps scalars from raising and stops a list such
                # as ["scores"] from being counted as a success.
                success = isinstance(response_data, dict) and (
                    "scores" in response_data or "logprobs" in response_data
                )
                if not success:
                    print(
                        f"[HTTP] Request {request_id} missing expected fields in response"
                    )

            completion_time = loop.time()
            elapsed_time = (completion_time - start_time) * 1000
            await results_queue.put((request_id, elapsed_time, success, completion_time))

    except Exception as e:
        # repr() keeps the exception type visible; timeouts and some other
        # errors have an empty str(), which used to log a blank message.
        print(f"[HTTP] Error for request {request_id}: {e!r}")
        completion_time = loop.time()
        await results_queue.put((request_id, 0, False, completion_time))
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# RESULTS
|
||||||
|
###############################################################################
|
||||||
|
async def process_results(
    results_queue,
    num_requests,
    send_duration,
    total_duration,
    rps,
    duration_secs,
    item_count,
    test_start_time,
):
    """
    Drain the result queue and summarize the run per minute and overall.

    Args:
        results_queue: Queue of ``(request_id, elapsed_ms, success,
            completion_time)`` tuples; exactly ``num_requests`` are awaited.
        num_requests: Number of results to read from the queue.
        send_duration: Seconds spent sending requests (reported as-is).
        total_duration: Seconds until every request finished (reported as-is).
        rps: Target request rate (reported as-is).
        duration_secs: Nominal test duration in seconds.
        item_count: Items per request (reported as-is).
        test_start_time: Clock value marking the start of minute 1; must use
            the same clock as the queued ``completion_time`` values.

    Returns:
        A list with one dict per minute interval. The list covers at least
        the nominal duration and is extended so that requests completing
        after it (the tail of the run) are still counted in a bucket.
    """
    all_results = []

    # Collect all results
    for _ in range(num_requests):
        request_id, elapsed_time, success, completion_time = await results_queue.get()
        all_results.append(
            {
                "request_id": request_id,
                "elapsed_time": elapsed_time,
                "success": success,
                "completion_time": completion_time,
            }
        )

    # Group results by the 60-second interval in which they completed.
    buckets = {}
    for entry in all_results:
        bucket_index = int((entry["completion_time"] - test_start_time) // 60)
        if bucket_index >= 0:
            buckets.setdefault(bucket_index, []).append(entry)

    # Minutes implied by the nominal duration, rounded up ...
    planned_minutes = int(duration_secs // 60) + (1 if duration_secs % 60 > 0 else 0)
    # ... extended to the last completion. Requests sent near the end finish
    # after the nominal window and would otherwise be missing from every
    # per-minute row.
    num_minutes = max(planned_minutes, max(buckets, default=-1) + 1)

    minute_results = []
    for minute in range(num_minutes):
        minute_data = buckets.get(minute, [])

        response_times = [r["elapsed_time"] for r in minute_data if r["success"]]
        successful_requests = len(response_times)
        failed_requests = len(minute_data) - successful_requests

        avg_response_time, p50, p90, p99 = _summarize_latencies(response_times)

        minute_results.append(
            {
                "test_duration_secs": duration_secs,
                "minute_interval": minute + 1,
                "target_rps": rps,
                "item_count": item_count,
                "server_type": SERVER_TYPE,
                "distribution": DISTRIBUTION,
                "unique_requests": NUM_UNIQUE_REQUESTS,
                "total_requests": len(minute_data),
                "successful_requests": successful_requests,
                "failed_requests": failed_requests,
                "send_duration_secs": send_duration,
                "total_duration_secs": total_duration,
                "avg_response_time_ms": avg_response_time,
                "p50_response_time_ms": p50,
                "p90_response_time_ms": p90,
                "p99_response_time_ms": p99,
            }
        )

        print(
            f"\nMinute {minute + 1} Summary for RPS {rps}, "
            f"Duration {duration_secs}s, Item Count {item_count}:"
        )
        print(f" Requests completed in minute: {len(minute_data)}")
        print(f" Successful requests: {successful_requests}")
        print(f" Failed requests: {failed_requests}")
        print(f" Average response time: {avg_response_time:.2f} ms")
        print(f" P50 response time: {p50:.2f} ms")
        print(f" P90 response time: {p90:.2f} ms")
        print(f" P99 response time: {p99:.2f} ms")

    # Also print overall summary
    all_response_times = [r["elapsed_time"] for r in all_results if r["success"]]
    total_successful = len(all_response_times)
    total_failed = len(all_results) - total_successful

    overall_avg, overall_p50, overall_p90, overall_p99 = _summarize_latencies(
        all_response_times
    )

    print(
        f"\nOverall Summary for RPS {rps}, Duration {duration_secs}s, "
        f"Item Count {item_count}:"
    )
    print(f" Test duration: {duration_secs} seconds")
    print(f" Server type: {SERVER_TYPE}")
    print(" HTTP mode: SINGLE_ITEM_SCORING")
    print(f" Target RPS: {rps}")
    print(f" Item count: {item_count}")
    print(f" Distribution: {DISTRIBUTION}")
    print(f" Unique requests generated: {NUM_UNIQUE_REQUESTS}")
    print(f" Total requests sent: {num_requests}")
    print(f" Successful requests: {total_successful}")
    print(f" Failed requests: {total_failed}")
    print(f" Time to send all requests: {send_duration:.2f} seconds")
    print(f" Time for all requests to complete: {total_duration:.2f} seconds")
    print(f" Average response time: {overall_avg:.2f} ms")
    print(f" P50 response time: {overall_p50:.2f} ms")
    print(f" P90 response time: {overall_p90:.2f} ms")
    print(f" P99 response time: {overall_p99:.2f} ms\n")

    return minute_results


def _summarize_latencies(latencies_ms):
    """Return (mean, p50, p90, p99) of latencies_ms, or four zeros if empty."""
    if not latencies_ms:
        return 0, 0, 0, 0
    return (
        mean(latencies_ms),
        np.percentile(latencies_ms, 50),
        np.percentile(latencies_ms, 90),
        np.percentile(latencies_ms, 99),
    )
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# MAIN
|
||||||
|
###############################################################################
|
||||||
|
async def run_benchmark(rps, duration_secs, item_count):
    """
    Run one benchmark configuration and return its per-minute results.

    ``int(rps * duration_secs)`` requests are prepared up front, sent at the
    target rate using the configured DISTRIBUTION, awaited, and summarized
    by ``process_results``.

    Args:
        rps: Target requests per second.
        duration_secs: Nominal sending duration in seconds.
        item_count: Number of items per score request.

    Returns:
        The list of per-minute result dicts produced by ``process_results``.

    Raises:
        ValueError: If DISTRIBUTION is neither "CONSTANT" nor "POISSON".
    """
    # Fail fast on a bad configuration, before any request is built or sent.
    if DISTRIBUTION not in ("CONSTANT", "POISSON"):
        raise ValueError(
            f"Unknown distribution: {DISTRIBUTION}. "
            f"Use 'CONSTANT' or 'POISSON'."
        )

    num_requests = int(rps * duration_secs)
    print(
        f"Starting benchmark with RPS={rps}, Duration={duration_secs}s, "
        f"Item Count={item_count}, num_requests={num_requests}"
    )
    print(f"Server Type: {SERVER_TYPE}")
    print("HTTP Mode: SINGLE_ITEM_SCORING")
    print(f"Profiling Enabled: {PROFILE}")

    # Build requests in parallel (unmeasured)
    all_requests = prepare_all_requests_parallel(num_requests, item_count)

    loop = asyncio.get_running_loop()
    results_queue = asyncio.Queue()
    tasks = []

    # HTTP implementation (open source only supports HTTP with /v1/score API)
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=300)
    ) as session:

        # Send START_PROFILE if profiling is enabled
        if PROFILE:
            await send_profile_request("START_PROFILE", item_count, session=session)

        # Start the clock only after profiling has been switched on, so the
        # profiler round trip is not counted as send time and does not shift
        # the minute buckets.
        send_start_time = loop.time()
        next_send_time = send_start_time

        with tqdm(
            total=len(all_requests),
            desc=f"Sending HTTP score requests at {rps} RPS",
            unit="req",
        ) as pbar:
            for i, score_data in enumerate(all_requests):
                request_id = i + 1
                tasks.append(
                    asyncio.create_task(
                        make_http_call(session, score_data, request_id, results_queue)
                    )
                )
                pbar.update(1)

                # Throttle based on distribution (no wait after the last one)
                if i < len(all_requests) - 1:
                    if DISTRIBUTION == "CONSTANT":
                        interval = 1 / rps
                    else:
                        # For a Poisson process, inter-arrival times follow
                        # an exponential distribution.
                        interval = random.expovariate(rps)
                    # Pace against absolute deadlines: sleeping a fixed
                    # interval after each send adds the per-iteration
                    # overhead on top and drifts below the target rate.
                    next_send_time += interval
                    await asyncio.sleep(max(0.0, next_send_time - loop.time()))

        send_duration = loop.time() - send_start_time

        # Wait for all requests to complete with progress tracking
        print(f"Waiting for {len(tasks)} HTTP score requests to complete...")
        with tqdm(
            total=len(tasks), desc="Completing HTTP score requests", unit="req"
        ) as completion_pbar:
            for task in asyncio.as_completed(tasks):
                await task
                completion_pbar.update(1)

        # Stop the clock before the STOP_PROFILE round trip for the same
        # reason the start was taken after START_PROFILE.
        total_duration = loop.time() - send_start_time

        # Send STOP_PROFILE if profiling is enabled
        if PROFILE:
            await send_profile_request("STOP_PROFILE", item_count, session=session)

    return await process_results(
        results_queue,
        num_requests,
        send_duration,
        total_duration,
        rps,
        duration_secs,
        item_count,
        send_start_time,
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Run every (duration, RPS, item count) combination and print a CSV.

    The combinations are executed sequentially. Each run contributes its
    per-minute result rows, which are printed as CSV at the end.
    """
    total_combinations = (
        len(DURATION_SECS_VALUES) * len(RPS_VALUES) * len(ITEM_COUNT_VALUES)
    )
    print(
        f"Running benchmarks for {len(DURATION_SECS_VALUES)} duration "
        f"values, {len(RPS_VALUES)} RPS values, and "
        f"{len(ITEM_COUNT_VALUES)} item count values = "
        f"{total_combinations} total combinations"
    )
    print(f"Server Type: {SERVER_TYPE}")
    print("HTTP Mode: SINGLE_ITEM_SCORING")
    print(f"Score API URL: {HTTP_URL}")
    print(f"Query tokens per request: {SCORE_QUERY_TOKENS}")
    print(f"Item tokens per item: {SCORE_ITEM_TOKENS}")
    print(f"Items per request (batch size): {ITEM_COUNT_VALUES}")
    print(f"Profiling Enabled: {PROFILE}")
    print(f"Duration values: {DURATION_SECS_VALUES}")
    print(f"RPS values: {RPS_VALUES}")
    print(f"Item count values: {ITEM_COUNT_VALUES}")
    separator = "=" * 80
    print(separator)

    # Duration is the outermost loop, item count the innermost.
    all_results = []
    for duration_secs in DURATION_SECS_VALUES:
        for rps in RPS_VALUES:
            for item_count in ITEM_COUNT_VALUES:
                minute_rows = await run_benchmark(rps, duration_secs, item_count)
                all_results.extend(minute_rows)  # One entry per minute

    # Print CSV header and results
    print("\n" + separator)
    print("FINAL CSV RESULTS:")
    print(separator)

    # Columns printed as-is, followed by columns printed with two decimals.
    raw_columns = [
        "test_duration_secs",
        "minute_interval",
        "target_rps",
        "item_count",
        "server_type",
        "distribution",
        "unique_requests",
        "total_requests",
        "successful_requests",
        "failed_requests",
    ]
    rounded_columns = [
        "send_duration_secs",
        "total_duration_secs",
        "avg_response_time_ms",
        "p50_response_time_ms",
        "p90_response_time_ms",
        "p99_response_time_ms",
    ]
    headers = raw_columns + rounded_columns
    print(",".join(headers))

    # CSV Data
    for result in all_results:
        row = [result[column] for column in raw_columns]
        row += [f"{result[column]:.2f}" for column in rounded_columns]
        print(",".join(map(str, row)))


if __name__ == "__main__":
    asyncio.run(main())
|
||||||
@@ -913,6 +913,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
# Whether to return hidden states
|
# Whether to return hidden states
|
||||||
return_hidden_states: bool = False
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
|
# Whether this batch is prefill-only (no token generation needed)
|
||||||
|
is_prefill_only: bool = False
|
||||||
|
|
||||||
# hicache pointer for synchronizing data loading from CPU to GPU
|
# hicache pointer for synchronizing data loading from CPU to GPU
|
||||||
hicache_consumer_index: int = 0
|
hicache_consumer_index: int = 0
|
||||||
|
|
||||||
@@ -953,6 +956,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
device=req_to_token_pool.device,
|
device=req_to_token_pool.device,
|
||||||
spec_algorithm=spec_algorithm,
|
spec_algorithm=spec_algorithm,
|
||||||
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
||||||
|
is_prefill_only=all(
|
||||||
|
req.sampling_params.max_new_tokens == 0 for req in reqs
|
||||||
|
),
|
||||||
chunked_req=chunked_req,
|
chunked_req=chunked_req,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1796,6 +1802,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||||
is_extend_in_batch=self.is_extend_in_batch,
|
is_extend_in_batch=self.is_extend_in_batch,
|
||||||
|
is_prefill_only=self.is_prefill_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _evict_tree_cache_if_needed(self, num_tokens: int):
|
def _evict_tree_cache_if_needed(self, num_tokens: int):
|
||||||
|
|||||||
@@ -1466,8 +1466,9 @@ class Scheduler(
|
|||||||
if self.last_batch.batch_size() < last_bs:
|
if self.last_batch.batch_size() < last_bs:
|
||||||
self.running_batch.batch_is_full = False
|
self.running_batch.batch_is_full = False
|
||||||
|
|
||||||
# Merge the new batch into the running batch
|
# Merge the new batch into the running batch.
|
||||||
if not self.last_batch.is_empty():
|
# For prefill-only batch, we can avoid going through decoding step.
|
||||||
|
if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
|
||||||
if self.running_batch.is_empty():
|
if self.running_batch.is_empty():
|
||||||
self.running_batch = self.last_batch
|
self.running_batch = self.last_batch
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -699,7 +699,7 @@ class TokenizerManager:
|
|||||||
# Process all requests
|
# Process all requests
|
||||||
tokenized_objs = []
|
tokenized_objs = []
|
||||||
for i, req in enumerate(requests):
|
for i, req in enumerate(requests):
|
||||||
self._validate_token_len(obj[i], input_ids_list[i])
|
self._validate_one_request(obj[i], input_ids_list[i])
|
||||||
tokenized_objs.append(
|
tokenized_objs.append(
|
||||||
self._create_tokenized_object(
|
self._create_tokenized_object(
|
||||||
req, req.text, input_ids_list[i], None, None
|
req, req.text, input_ids_list[i], None, None
|
||||||
@@ -1892,6 +1892,13 @@ class TokenizerManager:
|
|||||||
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
|
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_request = GenerateReqInput(
|
||||||
|
token_ids_logprob=label_token_ids,
|
||||||
|
return_logprob=True,
|
||||||
|
stream=False,
|
||||||
|
sampling_params={"max_new_tokens": 0},
|
||||||
|
)
|
||||||
|
|
||||||
# Handle string or tokenized query/items
|
# Handle string or tokenized query/items
|
||||||
if isinstance(query, str) and (
|
if isinstance(query, str) and (
|
||||||
isinstance(items, str)
|
isinstance(items, str)
|
||||||
@@ -1903,13 +1910,9 @@ class TokenizerManager:
|
|||||||
prompts = [f"{item}{query}" for item in items_list]
|
prompts = [f"{item}{query}" for item in items_list]
|
||||||
else:
|
else:
|
||||||
prompts = [f"{query}{item}" for item in items_list]
|
prompts = [f"{query}{item}" for item in items_list]
|
||||||
batch_request = GenerateReqInput(
|
|
||||||
text=prompts,
|
batch_request.text = prompts
|
||||||
return_logprob=True,
|
|
||||||
token_ids_logprob=label_token_ids,
|
|
||||||
stream=False,
|
|
||||||
sampling_params={"max_new_tokens": 1},
|
|
||||||
)
|
|
||||||
elif (
|
elif (
|
||||||
isinstance(query, list)
|
isinstance(query, list)
|
||||||
and isinstance(items, list)
|
and isinstance(items, list)
|
||||||
@@ -1921,13 +1924,8 @@ class TokenizerManager:
|
|||||||
input_ids_list = [item + query for item in items]
|
input_ids_list = [item + query for item in items]
|
||||||
else:
|
else:
|
||||||
input_ids_list = [query + item for item in items]
|
input_ids_list = [query + item for item in items]
|
||||||
batch_request = GenerateReqInput(
|
|
||||||
input_ids=input_ids_list,
|
batch_request.input_ids = input_ids_list
|
||||||
return_logprob=True,
|
|
||||||
token_ids_logprob=label_token_ids,
|
|
||||||
stream=False,
|
|
||||||
sampling_params={"max_new_tokens": 1},
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid combination of query/items types for score_request."
|
"Invalid combination of query/items types for score_request."
|
||||||
@@ -1939,9 +1937,20 @@ class TokenizerManager:
|
|||||||
for result in results:
|
for result in results:
|
||||||
# Get logprobs for each token
|
# Get logprobs for each token
|
||||||
logprobs = {}
|
logprobs = {}
|
||||||
for logprob, token_id, _ in result["meta_info"].get(
|
|
||||||
"output_token_ids_logprobs", []
|
# For scoring requests, we read from output_token_ids_logprobs since we want
|
||||||
)[0]:
|
# the logprobs for specific tokens mentioned in the label_token_ids at
|
||||||
|
# the next position after the last token in the prompt
|
||||||
|
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
|
||||||
|
|
||||||
|
# Throw an error here if output_logprobs is None
|
||||||
|
if output_logprobs is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
|
||||||
|
"This usually indicates a problem with the scoring request or the backend output."
|
||||||
|
)
|
||||||
|
|
||||||
|
for logprob, token_id, _ in output_logprobs[0]:
|
||||||
if token_id in label_token_ids:
|
if token_id in label_token_ids:
|
||||||
logprobs[token_id] = logprob
|
logprobs[token_id] = logprob
|
||||||
|
|
||||||
|
|||||||
@@ -213,6 +213,88 @@ class TestScoreAPI(CustomTestCase):
|
|||||||
1.0, sum(score_list), 6, "Scores should sum to 1"
|
1.0, sum(score_list), 6, "Scores should sum to 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_score_request_construction(self):
    """Test that scoring requests are constructed to avoid decode phase.

    ``tokenizer_manager.generate_request`` is wrapped so that the request
    object built by ``engine.score`` is recorded before being forwarded to
    the real implementation. The recorded request must:

    1. carry ``max_new_tokens == 0`` for every sampling-params entry,
    2. carry the caller's ``label_token_ids`` in ``token_ids_logprob``,
    3. ask for logprobs and not stream.
    """
    from unittest.mock import patch

    # Capture the internal request to verify optimization
    captured_requests = []
    original_gen = self.engine.tokenizer_manager.generate_request

    async def mock_generate_request(req, request=None):
        # Record the request, then delegate to the real async generator so
        # the scoring call still produces genuine results.
        captured_requests.append(req)
        async for result in original_gen(req, request):
            yield result

    query = "What is the capital of"
    items = ["France", "Germany"]
    label_token_ids = [1, 2, 3]

    # Patch the generate_request method only for the duration of the call
    with patch.object(
        self.engine.tokenizer_manager,
        "generate_request",
        side_effect=mock_generate_request,
    ):
        scores = self.engine.score(
            query=query,
            items=items,
            label_token_ids=label_token_ids,
            apply_softmax=True,
        )

    # Verify we got results
    self.assertEqual(len(scores), len(items))

    # Verify the captured request has decode-avoiding properties
    self.assertEqual(len(captured_requests), 1)
    request = captured_requests[0]

    # Key assertions for decode phase avoidance:
    # 1. max_new_tokens should be 0 (prevents token generation).
    # sampling_params may be a single entry or a list with one entry per
    # batch item; normalise to a list and check every entry. A missing
    # value resolves to None (not 0) so that an absent setting fails the
    # assertion instead of passing vacuously.
    sampling_params = request.sampling_params
    if isinstance(sampling_params, list):
        per_item_params = sampling_params
    else:
        per_item_params = [sampling_params]
    self.assertTrue(
        per_item_params, "Scoring request should carry sampling params"
    )
    for params in per_item_params:
        if isinstance(params, dict):
            max_new_tokens = params.get("max_new_tokens")
        else:
            max_new_tokens = getattr(params, "max_new_tokens", None)
        self.assertEqual(
            max_new_tokens, 0, "max_new_tokens should be 0 to avoid decode phase"
        )

    # 2. Should have token_ids_logprob for scoring
    # Handle both single and batch request cases
    token_ids_logprob = request.token_ids_logprob
    if (
        isinstance(token_ids_logprob, list)
        and len(token_ids_logprob) > 0
        and isinstance(token_ids_logprob[0], list)
    ):
        # Batch case: token_ids_logprob is a list of lists
        # Each item in the batch should have the same label_token_ids
        for item_token_ids in token_ids_logprob:
            self.assertEqual(
                item_token_ids,
                label_token_ids,
                "Each batch item should have label_token_ids for scoring",
            )
    else:
        # Single request case
        self.assertEqual(
            token_ids_logprob,
            label_token_ids,
            "Should have label_token_ids for scoring",
        )

    # 3. Should request logprobs but not stream
    self.assertTrue(request.return_logprob, "Should request logprobs for scoring")
    self.assertFalse(request.stream, "Scoring requests should not stream")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
121
test/srt/test_tokenizer_batch_encode.py
Normal file
121
test/srt/test_tokenizer_batch_encode.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for enable_tokenizer_batch_encode feature.
|
||||||
|
|
||||||
|
This tests the batch tokenization functionality which allows processing
|
||||||
|
multiple text inputs in a single batch for improved performance.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_multimodal_input_validation
|
||||||
|
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_pretokenized_input_validation
|
||||||
|
python3 -m unittest test_tokenizer_batch_encode.TestTokenizerBatchEncode.test_valid_text_only_requests_pass_validation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import unittest
|
||||||
|
from typing import List
|
||||||
|
from unittest.mock import AsyncMock, Mock, call, patch
|
||||||
|
|
||||||
|
from sglang.srt.managers.io_struct import GenerateReqInput, TokenizedGenerateReqInput
|
||||||
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenizerBatchEncode(unittest.TestCase):
    """Checks the batch-encode flag and the batch tokenization validator.

    A ``TokenizerManager`` is built in ``setUp`` with ZMQ and tokenizer
    loading patched out; the tests then call
    ``_validate_batch_tokenization_constraints`` with mock batches.
    """

    @staticmethod
    def _batch_returning(req):
        """Return a mock batch whose index lookup always yields ``req``."""
        batch = Mock()
        # A plain function assigned as a magic method receives the mock
        # itself as its first argument.
        batch.__getitem__ = lambda self, i: req
        return batch

    def _assert_rejected(self, batch_size, batch, fragment):
        """Assert the validator raises ValueError mentioning ``fragment``."""
        with self.assertRaises(ValueError) as ctx:
            self.tokenizer_manager._validate_batch_tokenization_constraints(
                batch_size, batch
            )
        self.assertIn(fragment, str(ctx.exception))

    def setUp(self):
        """Build server/port args and a TokenizerManager with I/O patched."""
        self.server_args = ServerArgs(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            enable_tokenizer_batch_encode=True,
        )
        self.port_args = PortArgs.init_new(self.server_args)

        zmq_context_patch = patch("zmq.asyncio.Context")
        zmq_socket_patch = patch("sglang.srt.utils.get_zmq_socket")
        tokenizer_patch = patch("sglang.srt.hf_transformers_utils.get_tokenizer")
        with zmq_context_patch, zmq_socket_patch, tokenizer_patch as get_tok:
            get_tok.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(
                self.server_args, self.port_args
            )

    def test_batch_encode_enabled(self):
        """The flag set in setUp is reflected on the server args."""
        self.assertTrue(self.server_args.enable_tokenizer_batch_encode)

    def test_batch_encode_disabled(self):
        """Server args built with the flag off report it as off."""
        disabled_args = ServerArgs(
            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            enable_tokenizer_batch_encode=False,
        )
        self.assertFalse(disabled_args.enable_tokenizer_batch_encode)

    def test_multimodal_input_validation(self):
        """A request reporting multimodal input is rejected in batch mode."""
        mm_request = GenerateReqInput(text="test", image_data=["dummy"])
        mm_request.contains_mm_input = Mock(return_value=True)
        batch = self._batch_returning(mm_request)

        self.tokenizer_manager.is_generation = True

        self._assert_rejected(1, batch, "multimodal")

    def test_pretokenized_input_validation(self):
        """A request carrying input_ids is rejected in batch mode."""
        batch = self._batch_returning(GenerateReqInput(input_ids=[1, 2, 3]))
        self._assert_rejected(1, batch, "pre-tokenized")

    def test_input_embeds_validation(self):
        """A request carrying input_embeds is rejected in batch mode."""
        batch = self._batch_returning(GenerateReqInput(input_embeds=[0.1, 0.2]))
        self._assert_rejected(1, batch, "input_embeds")

    def test_valid_text_only_requests_pass_validation(self):
        """Three text-only requests pass the validator without raising."""

        def make_text_request(index):
            text_req = GenerateReqInput(text=f"test text {index}")
            text_req.contains_mm_input = Mock(return_value=False)
            return text_req

        # Create valid requests (text-only)
        text_requests = [make_text_request(index) for index in range(3)]

        batch = Mock()
        # A Mock used as the magic method is called without the owning mock.
        batch.__getitem__ = Mock(side_effect=lambda i: text_requests[i])

        # Any exception is reported as a test failure rather than an error.
        try:
            self.tokenizer_manager._validate_batch_tokenization_constraints(
                3, batch
            )
        except Exception as e:
            self.fail(f"Validation failed for valid text-only requests: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run this module's unittest cases with verbose output.
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=2)
|
||||||
Reference in New Issue
Block a user