Simple prefetch policy (#8692)
@@ -20,6 +20,8 @@ from sglang.bench_serving import (
     sample_random_requests,
 )
 
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
@@ -139,7 +141,7 @@ async def async_request_sglang_generate(
     """
     Sends a streaming request to the server. Gathers text token-by-token.
     """
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         headers = {}
         generated_text = ""
         ttft = 0.0
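
The two hunks above work together: without an explicit timeout, aiohttp caps every request at its default ClientTimeout(total=300), i.e. five minutes, which a long benchmark run with large output lengths can easily exceed; the new constant raises the cap to 20 hours and applies it session-wide. A minimal sketch of the same pattern (the fetch helper is illustrative):

import aiohttp

# total= bounds the entire request, including the time spent draining
# the streamed response body, not just connection setup.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)  # 20 hours

async def fetch(url: str, payload: dict) -> str:
    # Set once on the session, the timeout applies to every request
    # issued through it.
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url, json=payload) as resp:
            return await resp.text()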
@@ -150,6 +152,8 @@ async def async_request_sglang_generate(
         try:
             async with session.post(url=url, json=payload, headers=headers) as response:
                 if response.status == 200:
+                    prompt_tokens = 0
+                    cached_tokens = 0
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
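
Binding prompt_tokens and cached_tokens to 0 before the loop guarantees both names exist even if the stream yields no parsable chunk, so the later copies onto the output object fall back to zero instead of raising NameError. Reduced to its core (the consume helper is illustrative):

async def consume(response):
    prompt_tokens = 0   # defaults survive an empty or malformed stream
    cached_tokens = 0
    async for chunk_bytes in response.content:
        pass  # per-chunk parsing, shown in the next hunk
    return prompt_tokens, cached_tokens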
@@ -168,6 +172,12 @@ async def async_request_sglang_generate(
                             if ttft == 0.0:
                                 ttft = time.perf_counter() - st
                                 output.ttft = ttft
+                                prompt_tokens = (data.get("meta_info") or {}).get(
+                                    "prompt_tokens", 0
+                                )
+                                cached_tokens = (data.get("meta_info") or {}).get(
+                                    "cached_tokens", 0
+                                )
 
                             # Decoding phase
                             else:
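
The values come from the meta_info dict that SGLang attaches to streamed chunks: prompt_tokens is the tokenized input length and cached_tokens the portion of the prompt served from the prefix cache. A hedged sketch of parsing one chunk; the exact wire format shown is an assumption, but the fields match those read above:

import json

# Hypothetical chunk as it might arrive on the wire.
chunk = b'data: {"text": "Hi", "meta_info": {"prompt_tokens": 512, "cached_tokens": 384}}'

line = chunk.decode("utf-8")
if line.startswith("data: ") and line != "data: [DONE]":
    data = json.loads(line[len("data: "):])
    meta = data.get("meta_info") or {}            # tolerate a missing dict
    prompt_tokens = meta.get("prompt_tokens", 0)  # full input length
    cached_tokens = meta.get("cached_tokens", 0)  # prefix-cache hits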
@@ -179,6 +189,8 @@ async def async_request_sglang_generate(
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
+                    output.prompt_len = prompt_tokens
+                    output.cached_tokens = cached_tokens
                 else:
                     output.error = response.reason or ""
                     output.success = False
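
With the two new assignments, every successful request now carries its input length and cache reuse next to the timing fields. An illustrative stand-in for the output record (field names from the diff; the dataclass itself is an assumption):

from dataclasses import dataclass

@dataclass
class RequestOutput:
    success: bool = False
    generated_text: str = ""
    ttft: float = 0.0        # time to first token, seconds
    latency: float = 0.0     # end-to-end latency, seconds
    prompt_len: int = 0      # prompt_tokens from meta_info
    cached_tokens: int = 0   # tokens served from the prefix cache
    error: str = ""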
@@ -201,6 +213,7 @@ def gen_payload(prompt, output_len):
             "ignore_eos": True,
         },
         "stream": True,
+        "stream_options": {"include_usage": True},
         "lora_path": "",
         "return_logprob": False,
         "logprob_start_len": -1,
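
The added stream_options entry asks the server to include usage accounting in the stream, mirroring the OpenAI-style streaming option. A sketch of the payload gen_payload builds; keys not visible in this hunk are assumptions:

def gen_payload(prompt: str, output_len: int) -> dict:
    return {
        "text": prompt,                    # assumed prompt field
        "sampling_params": {
            "max_new_tokens": output_len,  # assumed; fixes the output length
            "ignore_eos": True,            # from the diff: never stop early
        },
        "stream": True,                    # token-by-token streaming
        "stream_options": {"include_usage": True},  # new in this commit
        "lora_path": "",
        "return_logprob": False,
        "logprob_start_len": -1,
    }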
@@ -303,7 +316,12 @@ class WorkloadGenerator:
 
         self.response_queue = queue.Queue()
         self.pbar = tqdm(total=args.num_clients * args.num_rounds)
-        self.performance_metrics = {"ttft": [], "latency": []}
+        self.performance_metrics = {
+            "ttft": [],
+            "latency": [],
+            "prompt_len": [],
+            "cached_tokens": [],
+        }
 
     async def handle_request(self, item):
         try:
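
Keeping raw per-request samples, rather than running aggregates, lets the summary derive any statistic after the run. For instance, a nearest-rank percentile over the collected TTFTs (sketch with hypothetical data):

def percentile(samples, p):
    # Nearest-rank percentile over raw per-request samples.
    ordered = sorted(samples)
    return ordered[min(len(ordered) - 1, int(p * len(ordered)))]

ttfts = [0.21, 0.35, 0.19, 0.42, 0.27]  # hypothetical per-request TTFTs
print(percentile(ttfts, 0.50), percentile(ttfts, 0.99))  # 0.27 0.42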
@@ -360,6 +378,8 @@ class WorkloadGenerator:
             self.client_records[client_id]["round"] += 1
             self.performance_metrics["ttft"].append(response.ttft)
             self.performance_metrics["latency"].append(response.latency)
+            self.performance_metrics["prompt_len"].append(response.prompt_len)
+            self.performance_metrics["cached_tokens"].append(response.cached_tokens)
             self.completed_requests += 1
 
             if self.client_records[client_id]["round"] < args.num_rounds:
@@ -416,6 +436,12 @@ class WorkloadGenerator:
                     len(self.performance_metrics["latency"]) // 2
                 ],
                 "throughput": self.pbar.total / (self.finished_time - self.start_time),
+                "cache_hit_rate": (
+                    0
+                    if sum(self.performance_metrics["prompt_len"]) == 0
+                    else sum(self.performance_metrics["cached_tokens"])
+                    / sum(self.performance_metrics["prompt_len"])
+                ),
             },
         }
         print("All requests completed")
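
The hit rate is total cached tokens over total prompt tokens across all requests, guarded against an empty run. A quick worked check with hypothetical numbers:

prompt_len = [512, 1024, 768]     # hypothetical per-request prompt tokens
cached_tokens = [0, 896, 640]     # of which served from the prefix cache

total_prompt = sum(prompt_len)    # 2304
total_cached = sum(cached_tokens) # 1536
hit_rate = 0 if total_prompt == 0 else total_cached / total_prompt
print(f"{hit_rate:.6f}")          # 0.666667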
@@ -434,6 +460,7 @@ class WorkloadGenerator:
         print(
             f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
         )
+        print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
         log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
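
log_to_jsonl_file persists the same summary for later comparison across runs; a plausible sketch of such a logger (the real implementation lives in the script and may differ):

import json

def log_to_jsonl_file(record: dict, path: str, tag: str = "") -> None:
    # One JSON object per line, tagged so runs can be filtered later.
    with open(path, "a") as f:
        f.write(json.dumps({**record, "tag": tag}) + "\n")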