diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py
new file mode 100644
index 000000000..ab34c33da
--- /dev/null
+++ b/benchmark/hicache/bench_multiturn.py
@@ -0,0 +1,338 @@
+import argparse
+import asyncio
+import json
+import queue
+import random
+import threading
+import time
+from typing import Optional
+
+import aiohttp
+import requests
+from tqdm.asyncio import tqdm
+
+from sglang.bench_serving import (
+    RequestFuncOutput,
+    get_tokenizer,
+    remove_prefix,
+    sample_random_requests,
+)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Script to benchmark concurrent requests to a server."
+    )
+    parser.add_argument(
+        "--num-clients",
+        type=int,
+        default=200,
+        help="Number of concurrent clients",
+    )
+    parser.add_argument(
+        "--request-length",
+        type=int,
+        default=512,
+        help="Length of each new request",
+    )
+    parser.add_argument(
+        "--output-length",
+        type=int,
+        default=64,
+        help="Length of each output",
+    )
+    parser.add_argument(
+        "--num-rounds",
+        type=int,
+        default=5,
+        help="Number of rounds per client",
+    )
+    parser.add_argument(
+        "--distribution",
+        type=str,
+        default="poisson",
+        choices=["poisson", "uniform"],
+        help="Distribution type for request intervals (poisson or uniform)",
+    )
+    parser.add_argument(
+        "--request-rate",
+        type=float,
+        default=1.0,
+        help="Average number of requests per second",
+    )
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="localhost",
+        help="Server hostname or IP (default: localhost)",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=30000,
+        help="Server port (default: 30000)",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="meta-llama/Llama-3.1-8B-Instruct",
+        help="Model path compatible with Hugging Face Transformers",
+    )
+    return parser.parse_args()
+
+
+async def async_request_sglang_generate(
+    payload,
+    url,
+    pbar: Optional[tqdm] = None,
+):
+    """
+    Sends a streaming request to the server. Gathers text token by token.
+ """ + async with aiohttp.ClientSession() as session: + headers = {} + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + output = RequestFuncOutput() + + try: + async with session.post(url=url, json=payload, headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + if data["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text = data["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + else: + output.error = response.reason or "" + output.success = False + except Exception as e: + output.success = False + output.error = str(e) + print(f"Request failed: {e}") + + if pbar: + pbar.update(1) + return output + + +def gen_payload(prompt, output_len): + payload = { + "text": prompt, + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": output_len, + "ignore_eos": True, + }, + "stream": True, + "lora_path": "", + "return_logprob": False, + "logprob_start_len": -1, + } + return payload + + +class ReadyQueue: + """ + Thread-safe queue that can pop requests in different orders based on given policy. + """ + + def __init__(self, init_requests=None, policy="random"): + self.lock = threading.Lock() + self.requests = init_requests or [] + self.policy = policy + + def append(self, item): + with self.lock: + self.requests.append(item) + + def pop(self): + with self.lock: + if not self.requests: + return None + if self.policy == "random": + index = random.randrange(len(self.requests)) + return self.requests.pop(index) + elif self.policy == "fifo": + return self.requests.pop(0) + else: + # todo, varying thinking time of clients + raise ValueError(f"{self.policy} not implemented") + + +class WorkloadGenerator: + def __init__(self, args): + # Construct the base URL for requests + self.url = f"http://{args.host}:{args.port}/generate" + + self.tokenizer = get_tokenizer(args.model) + self.distribution = args.distribution + self.request_rate = args.request_rate + self.start_time = None + self.finished_time = None + + self.candidate_inputs = sample_random_requests( + input_len=args.request_length, + output_len=args.output_length, + num_prompts=args.num_clients * args.num_rounds, + range_ratio=1.0, + tokenizer=self.tokenizer, + dataset_path="", + ) + self.candidate_inputs = [i[0] for i in self.candidate_inputs] + + init_requests = [ + (i, gen_payload(self.candidate_inputs[i], args.output_length)) + for i in range(args.num_clients) + ] + self.client_records = { + i: {"round": 0, "history": init_requests[i][1]["text"]} + for i in range(args.num_clients) + } + self.ready_queue = ReadyQueue(init_requests=init_requests) + self.candidate_inputs = self.candidate_inputs[args.num_clients :] + + self.response_queue = queue.Queue() + self.pbar = tqdm(total=args.num_clients * args.num_rounds) + self.performance_metrics = {"ttft": [], "latency": []} + + async def handle_request(self, item): + try: + client_id, payload = item + response = await async_request_sglang_generate(payload, self.url, self.pbar) + if self.pbar.n == 
+                self.finished_time = time.time()
+            self.response_queue.put((client_id, response))
+        except Exception as e:
+            print(f"Request failed: {e}")
+
+    def request_sender(self):
+        async def request_loop():
+            while True:
+                # Draw the next inter-arrival time from the configured distribution
+                if self.distribution == "poisson":
+                    sleep_time = random.expovariate(self.request_rate)
+                elif self.distribution == "uniform":
+                    avg_interval = (
+                        1.0 / self.request_rate if self.request_rate > 0 else 1.0
+                    )
+                    sleep_time = random.uniform(0, 2 * avg_interval)
+                else:
+                    raise ValueError("Invalid distribution type")
+                await asyncio.sleep(sleep_time)  # Wait before sending the next request
+
+                new_request = self.ready_queue.pop()
+                # Submit async request
+                if new_request:
+                    asyncio.create_task(self.handle_request(new_request))
+                else:
+                    if self.pbar.n == self.pbar.total:
+                        break
+
+        # Create and run the event loop for asynchronous requests
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(request_loop())
+        loop.close()
+
+    def response_handler(self):
+        while True:
+            try:
+                client_id, response = self.response_queue.get(
+                    timeout=10
+                )  # Block until a response is available
+                if not response.success:
+                    raise ValueError(f"Request failed with error: {response.error}")
+                self.client_records[client_id]["history"] += response.generated_text
+                self.client_records[client_id]["round"] += 1
+                self.performance_metrics["ttft"].append(response.ttft)
+                self.performance_metrics["latency"].append(response.latency)
+
+                if self.client_records[client_id]["round"] < self.args.num_rounds:
+                    self.client_records[client_id][
+                        "history"
+                    ] += self.candidate_inputs.pop()
+                    self.ready_queue.append(
+                        (
+                            client_id,
+                            gen_payload(
+                                self.client_records[client_id]["history"],
+                                self.args.output_length,
+                            ),
+                        )
+                    )
+            except queue.Empty:
+                if self.pbar.n == self.pbar.total:
+                    break
+
+    def run(self):
+        request_thread = threading.Thread(target=self.request_sender, daemon=True)
+        response_thread = threading.Thread(target=self.response_handler, daemon=True)
+
+        self.start_time = time.time()
+        request_thread.start()
+        response_thread.start()
+
+        request_thread.join()
+        response_thread.join()
+
+        self.pbar.close()
+        print("All requests completed.")
+        print("Performance metrics summary:")
+        print(
+            f"  Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second"
+        )
+        print(
+            f"  Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f} s"
+        )
+        print(
+            f"  Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f} s"
+        )
+        print(
+            f"  Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f} s"
+        )
+        print(
+            f"  Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f} s"
+        )
+        throughput = self.pbar.total / (self.finished_time - self.start_time)
+        print(f"Throughput: {throughput:.2f} requests per second")
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
+
+    # Sweep the request rate from 1 to 39 req/s in steps of 2, flushing the
+    # server cache before each run; this overrides any --request-rate flag.
+    for request_rate in range(1, 41, 2):
+        args.request_rate = request_rate
+        requests.post(flush_cache_url)
+        WorkloadGenerator(args).run()
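
For reference, a minimal usage sketch (not part of the diff): it drives a single
benchmark run at one fixed request rate instead of the 1-to-39 sweep in the
__main__ block above. It assumes an SGLang server is already listening on
args.host:args.port and that this file is importable (e.g. the sketch is run
from benchmark/hicache/); the fixed rate of 8.0 is an arbitrary example value.

    import requests

    from bench_multiturn import WorkloadGenerator, parse_args

    args = parse_args()  # accepts the same flags, e.g. --num-clients 50 --num-rounds 3
    args.request_rate = 8.0  # pin a single rate instead of sweeping
    requests.post(f"http://{args.host}:{args.port}/flush_cache")  # start from a cold cache
    WorkloadGenerator(args).run()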