Simple prefetch policy (#8692)

commit e2fd2b9c7e
parent 7490e3f67d
Author: pansicheng
Date:   2025-08-08 17:09:28 +08:00
Committed by: GitHub

6 changed files with 148 additions and 36 deletions


@@ -20,6 +20,8 @@ from sglang.bench_serving import (
     sample_random_requests,
 )
 
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
@@ -139,7 +141,7 @@ async def async_request_sglang_generate(
     """
    Sends a streaming request to the server. Gathers text token-by-token.
    """
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         headers = {}
         generated_text = ""
         ttft = 0.0
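
Note: AIOHTTP_TIMEOUT above works out to 20 * 60 * 60 = 72,000 seconds (20 hours); aiohttp's default total timeout is only 5 minutes, which a long benchmark run would exceed. A minimal sketch of how a session-wide timeout behaves (names mirror the diff):

    import aiohttp

    AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)  # 20 hours

    async def post_once(url: str, payload: dict) -> int:
        # A timeout passed at session creation becomes the default for
        # every request made through that session, so each session.post()
        # inherits it without per-call arguments.
        async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
            async with session.post(url, json=payload) as response:
                return response.status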
@@ -150,6 +152,8 @@
         try:
             async with session.post(url=url, json=payload, headers=headers) as response:
                 if response.status == 200:
+                    prompt_tokens = 0
+                    cached_tokens = 0
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
@@ -168,6 +172,12 @@
                                 if ttft == 0.0:
                                     ttft = time.perf_counter() - st
                                     output.ttft = ttft
+                                prompt_tokens = (data.get("meta_info") or {}).get(
+                                    "prompt_tokens", 0
+                                )
+                                cached_tokens = (data.get("meta_info") or {}).get(
+                                    "cached_tokens", 0
+                                )
 
                             # Decoding phase
                             else:
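
For context, here is a minimal sketch of the extraction the hunk above performs on each streamed chunk, assuming SSE-style "data: {...}" lines whose JSON carries a "meta_info" object (the field names come from the diff; the helper itself is illustrative):

    import json

    def extract_token_counts(chunk_bytes: bytes) -> tuple[int, int]:
        line = chunk_bytes.strip()
        # Strip the SSE "data:" prefix, if present.
        if line.startswith(b"data:"):
            line = line[len(b"data:"):].strip()
        if not line or line == b"[DONE]":
            return 0, 0
        data = json.loads(line)
        # (data.get("meta_info") or {}) guards against meta_info being
        # absent or null in intermediate chunks.
        meta = data.get("meta_info") or {}
        return meta.get("prompt_tokens", 0), meta.get("cached_tokens", 0)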
@@ -179,6 +189,8 @@
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
+                    output.prompt_len = prompt_tokens
+                    output.cached_tokens = cached_tokens
                 else:
                     output.error = response.reason or ""
                     output.success = False
@@ -201,6 +213,7 @@ def gen_payload(prompt, output_len):
             "ignore_eos": True,
         },
         "stream": True,
+        "stream_options": {"include_usage": True},
         "lora_path": "",
         "return_logprob": False,
         "logprob_start_len": -1,
@@ -303,7 +316,12 @@ class WorkloadGenerator:
         self.response_queue = queue.Queue()
         self.pbar = tqdm(total=args.num_clients * args.num_rounds)
-        self.performance_metrics = {"ttft": [], "latency": []}
+        self.performance_metrics = {
+            "ttft": [],
+            "latency": [],
+            "prompt_len": [],
+            "cached_tokens": [],
+        }
 
     async def handle_request(self, item):
         try:
@@ -360,6 +378,8 @@
             self.client_records[client_id]["round"] += 1
             self.performance_metrics["ttft"].append(response.ttft)
             self.performance_metrics["latency"].append(response.latency)
+            self.performance_metrics["prompt_len"].append(response.prompt_len)
+            self.performance_metrics["cached_tokens"].append(response.cached_tokens)
             self.completed_requests += 1
 
             if self.client_records[client_id]["round"] < args.num_rounds:
@@ -416,6 +436,12 @@
                     len(self.performance_metrics["latency"]) // 2
                 ],
                 "throughput": self.pbar.total / (self.finished_time - self.start_time),
+                "cache_hit_rate": (
+                    0
+                    if sum(self.performance_metrics["prompt_len"]) == 0
+                    else sum(self.performance_metrics["cached_tokens"])
+                    / sum(self.performance_metrics["prompt_len"])
+                ),
             },
         }
         print("All requests completed")
@@ -434,6 +460,7 @@
         print(
             f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
         )
+        print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
         log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
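
log_to_jsonl_file() is not defined in this diff; a hypothetical sketch of what such a helper typically does (append one tagged JSON object per run to a JSON-Lines file):

    import json
    import time

    def log_to_jsonl_file(data: dict, path: str, tag: str = "") -> None:
        # Hypothetical helper: one JSON object per line, so repeated
        # benchmark runs accumulate into a grep-able history.
        record = {"tag": tag, "timestamp": time.time(), **data}
        with open(path, "a") as f:
            f.write(json.dumps(record) + "\n")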