From 793b79dbe901fd2f4257744125f15edcc14567f4 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sun, 3 Nov 2024 12:56:10 -0800
Subject: [PATCH] feat: support truss endpoint for benchmark serving (#1906)

---
 python/sglang/bench_serving.py | 92 ++++++++++++++++++++++++++++++++++
 1 file changed, 92 insertions(+)

diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 2ca35aca9..8bb452cd0 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output
 
 
+async def async_request_truss(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "model": request_func_input.model,
+            "prompt": prompt,
+            "temperature": 0.0,
+            "best_of": 1,
+            "max_tokens": request_func_input.output_len,
+            "stream": not args.disable_stream,
+            "ignore_eos": not args.disable_ignore_eos,
+            **request_func_input.extra_request_body,
+        }
+        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion APIs send a final usage
+                            # summary response without a token, so we check
+                            # that a token was actually generated.
+                            if data["choices"][0]["delta"]["content"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text += data["choices"][0]["delta"]["content"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_sglang_generate(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -350,6 +429,7 @@ ASYNC_REQUEST_FUNCS = {
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
     "gserver": async_request_gserver,
+    "truss": async_request_truss,
 }
 
 
@@ -873,6 +953,7 @@ def run_benchmark(args_: argparse.Namespace):
             "vllm": 8000,
             "trt": 8000,
             "gserver": 9988,
+            "truss": 8080,
         }.get(args.backend, 30000)
 
     model_url = (
@@ -905,9 +986,20 @@ def run_benchmark(args_: argparse.Namespace):
     elif args.backend == "gserver":
         api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
         args.model = args.model or "default"
+    elif args.backend == "truss":
+        api_url = (
+            f"{args.base_url}/v1/models/model:predict"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/models/model:predict"
+        )
 
     # Get model name
     if args.model is None:
+        if args.backend == "truss":
+            print(
+                "Please provide a model with `--model` when using the truss backend, e.g. --model meta-llama/Llama-3.1-8B-Instruct"
+            )
+            sys.exit(1)
         try:
             response = requests.get(model_url)
             model_list = response.json().get("data", [])
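
Usage note: with this patch applied, a Truss deployment can be benchmarked by selecting the new backend. Because a Truss endpoint exposes no `/v1/models` route, the model name cannot be autodetected and `--model` is required; the backend defaults to port 8080 and targets the `/v1/models/model:predict` route. A typical invocation (the model name here is only an example) would look like:

    python3 -m sglang.bench_serving --backend truss --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 100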
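For reference, a minimal standalone sketch of the exchange that `async_request_truss` expects: a POST to the predict route answered by an OpenAI-style stream of `data: ` lines whose JSON chunks carry `choices[0].delta.content`, terminated by `data: [DONE]`. The URL, port, model name, and prompt below are illustrative assumptions, not values taken from the patch:

    # Standalone sketch; mirrors the parsing logic of async_request_truss.
    # The endpoint URL, model name, and prompt are illustrative assumptions.
    import asyncio
    import json
    import os

    import aiohttp


    async def main() -> None:
        api_url = "http://127.0.0.1:8080/v1/models/model:predict"  # default truss port
        payload = {
            "model": "meta-llama/Llama-3.1-8B-Instruct",  # example; any served model
            "prompt": "The capital of France is",
            "temperature": 0.0,
            "max_tokens": 32,
            "stream": True,
        }
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

        async with aiohttp.ClientSession() as session:
            async with session.post(api_url, json=payload, headers=headers) as resp:
                resp.raise_for_status()
                # The server streams SSE lines: b'data: {...}\n' and a final
                # b'data: [DONE]\n'; each JSON chunk carries an OpenAI-style delta.
                async for chunk_bytes in resp.content:
                    chunk = chunk_bytes.strip().decode("utf-8")
                    if not chunk:
                        continue
                    if chunk.startswith("data: "):
                        chunk = chunk[len("data: "):]
                    if chunk == "[DONE]":
                        break
                    data = json.loads(chunk)
                    content = data["choices"][0]["delta"]["content"]
                    if content:
                        print(content, end="", flush=True)
        print()


    if __name__ == "__main__":
        asyncio.run(main())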