diff --git a/benchmark/lora/launch_server.py b/benchmark/lora/launch_server.py
new file mode 100644
index 000000000..1fa4d7135
--- /dev/null
+++ b/benchmark/lora/launch_server.py
@@ -0,0 +1,53 @@
+import argparse
+import os
+
+# NUM_LORAS is imported by lora_bench.py; keep it in sync with --num-loras.
+NUM_LORAS = 128
+LORA_PATH = {
+    "base": "mistralai/Mistral-7B-Instruct-v0.3",
+    "lora": "/home/ying/test_lora",
+}
+
+
+def launch_server(args):
+    base_path = LORA_PATH["base"]
+    lora_path = LORA_PATH["lora"]
+
+    if args.base_only:
+        cmd = f"python -m sglang.launch_server --model {base_path} "
+    else:
+        cmd = f"python -m sglang.launch_server --model {base_path} --lora-paths "
+        # Register every adapter under a unique name; they all share one path.
+        for i in range(args.num_loras):
+            lora_name = f"lora{i}"
+            cmd += f"{lora_name}={lora_path} "
+    cmd += "--disable-radix --disable-cuda-graph "
+    cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
+    cmd += f"--max-running-requests {args.max_running_requests}"
+    print(cmd)
+    os.system(cmd)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--num-loras",
+        type=int,
+        default=128,
+    )
+    parser.add_argument(
+        "--base-only",
+        action="store_true",
+    )
+    parser.add_argument(
+        "--max-loras-per-batch",
+        type=int,
+        default=8,
+    )
+    parser.add_argument(
+        "--max-running-requests",
+        type=int,
+        default=8,
+    )
+    args = parser.parse_args()
+
+    launch_server(args)
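
For reference, a minimal sketch (not part of the patch) of the command string `launch_server()` composes, assuming the default paths above and `--num-loras 2`:

```python
# Illustrative only: mirrors the string-building logic in launch_server().
base = "mistralai/Mistral-7B-Instruct-v0.3"
lora = "/home/ying/test_lora"
cmd = f"python -m sglang.launch_server --model {base} --lora-paths "
cmd += "".join(f"lora{i}={lora} " for i in range(2))
cmd += "--disable-radix --disable-cuda-graph "
cmd += "--max-loras-per-batch 8 --max-running-requests 8"
print(cmd)
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 \
#   --lora-paths lora0=/home/ying/test_lora lora1=/home/ying/test_lora \
#   --disable-radix --disable-cuda-graph --max-loras-per-batch 8 --max-running-requests 8
```
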
diff --git a/benchmark/lora/lora_bench.py b/benchmark/lora/lora_bench.py
new file mode 100644
index 000000000..24087067b
--- /dev/null
+++ b/benchmark/lora/lora_bench.py
@@ -0,0 +1,485 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import argparse
+import asyncio
+import json
+import os
+import random
+import resource
+import sys
+import time
+import traceback
+import warnings
+from argparse import ArgumentParser
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
+
+import aiohttp
+import numpy as np
+import requests
+from launch_server import LORA_PATH, NUM_LORAS
+from tqdm.asyncio import tqdm
+from transformers import (
+    AutoTokenizer,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
+    PreTrainedTokenizerFast,
+)
+
+from sglang.bench_serving import (
+    AIOHTTP_TIMEOUT,
+    SHAREGPT_URL,
+    BenchmarkMetrics,
+    RequestFuncInput,
+    RequestFuncOutput,
+    calculate_metrics,
+    check_chat_template,
+    get_model,
+    get_request,
+    get_tokenizer,
+    parse_request_rate_range,
+    remove_prefix,
+    sample_random_requests,
+)
+
+global args
+
+
+# Adapted from sglang.bench_serving: posts to the native /generate endpoint
+# instead of the OpenAI-compatible completions API.
+async def async_request_openai_completions(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    # assert api_url.endswith(
+    #     "completions"
+    # ), "OpenAI Completions API URL must end with 'completions'."
+
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        # payload = {
+        #     "model": request_func_input.model,
+        #     "prompt": prompt,
+        #     "temperature": 0.0,
+        #     "best_of": 1,
+        #     "max_tokens": request_func_input.output_len,
+        #     "stream": not args.disable_stream,
+        #     "ignore_eos": not args.disable_ignore_eos,
+        #     **request_func_input.extra_request_body,
+        # }
+        # headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
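+        # The benchmark targets SGLang's native /generate endpoint rather than
+        # the OpenAI-compatible /v1/completions, so the payload uses the native
+        # schema; unless --base-only is set, each request picks one of the
+        # NUM_LORAS registered adapter names at random.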
+        if args.base_only:
+            payload = {
+                "text": prompt,
+                "sampling_params": {"max_new_tokens": request_func_input.output_len},
+            }
+        else:
+            payload = {
+                "text": prompt,
+                "sampling_params": {"max_new_tokens": request_func_input.output_len},
+                "lora_path": f"lora{random.randint(0, NUM_LORAS - 1)}",
+            }
+        headers = {"Authorization": ""}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion APIs send a final usage
+                            # summary chunk without a token, so check that a
+                            # token was actually generated.
+                            if data["text"]:
+                                # if data["choices"][0]["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                # generated_text += data["choices"][0]["text"]
+                                generated_text += data["text"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
+ASYNC_REQUEST_FUNCS = {
+    "sglang": async_request_openai_completions,
+}
+
+
+async def benchmark(
+    backend: str,
+    api_url: str,
+    model_id: str,
+    tokenizer: PreTrainedTokenizerBase,
+    input_requests: List[Tuple[str, int, int]],
+    request_rate: float,
+    disable_tqdm: bool,
+    extra_request_body: Dict[str, Any],
+):
+    if backend in ASYNC_REQUEST_FUNCS:
+        request_func = ASYNC_REQUEST_FUNCS[backend]
+    else:
+        raise ValueError(f"Unknown backend: {backend}")
+
+    print("Starting initial single prompt test run...")
+    test_prompt, test_prompt_len, test_output_len = input_requests[0]
+    test_input = RequestFuncInput(
+        model=model_id,
+        prompt=test_prompt,
+        api_url=api_url,
+        prompt_len=test_prompt_len,
+        output_len=test_output_len,
+        extra_request_body=extra_request_body,
+    )
+    test_output = await request_func(request_func_input=test_input)
+    if not test_output.success:
+        raise ValueError(
+            "Initial test run failed - Please make sure benchmark arguments "
+            f"are correctly specified. Error: {test_output.error}"
+        )
+    else:
+        print("Initial test run completed. Starting main benchmark run...")
+
+    pbar = None if disable_tqdm else tqdm(total=len(input_requests))
+
+    benchmark_start_time = time.perf_counter()
+    tasks: List[asyncio.Task] = []
+    async for request in get_request(input_requests, request_rate):
+        prompt, prompt_len, output_len = request
+        request_func_input = RequestFuncInput(
+            model=model_id,
+            prompt=prompt,
+            api_url=api_url,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            extra_request_body=extra_request_body,
+        )
+        tasks.append(
+            asyncio.create_task(
+                request_func(request_func_input=request_func_input, pbar=pbar)
+            )
+        )
+    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
+
+    if pbar is not None:
+        pbar.close()
+
+    benchmark_duration = time.perf_counter() - benchmark_start_time
+
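+    # Metric definitions come from sglang.bench_serving.calculate_metrics:
+    # TTFT is the time to the first streamed chunk, and TPOT is roughly
+    # (latency - TTFT) / (output_len - 1) for each request.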
1st token)", n=50, c="-") + ) + print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) + print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms)) + print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) + print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-")) + print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) + print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) + print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) + print("=" * 50) + + if ( + metrics.median_ttft_ms is not None + and metrics.mean_itl_ms is not None + and metrics.output_throughput is not None + ): + result = { + "backend": args.backend, + "request_rate": request_rate, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "median_itl_ms": metrics.median_itl_ms, + "output_throughput": metrics.output_throughput, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + "duration": benchmark_duration, + "completed": metrics.completed, + } + else: + print(f"Error running benchmark for request rate: {request_rate}") + print("-" * 30) + + # Determine output file name + if args.output_file: + output_file_name = args.output_file + else: + now = datetime.now().strftime("%m%d") + output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl" + + # Append results to a JSONL file + with open(output_file_name, "a") as file: + file.write(json.dumps(result) + "\n") + + result = { + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, + "mean_ttft_ms": metrics.mean_ttft_ms, + "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, + "median_itl_ms": metrics.median_itl_ms, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, + "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + } + return result + + +def run_benchmark(args_: argparse.Namespace): + global args + args = args_ + + # Set global environments + set_ulimit() + random.seed(args.seed) + np.random.seed(args.seed) + + # Set url + if args.port is None: + args.port = { + "sglang": 30000, + }.get(args.backend, 30000) + + # api_url = ( + # f"{args.base_url}/v1/completions" + # if args.base_url + # else 
f"http://{args.host}:{args.port}/v1/completions" + # ) + api_url = ( + f"{args.base_url}/generate" + if args.base_url + else f"http://{args.host}:{args.port}/generate" + ) + + print(f"{args}\n") + + # Read dataset + backend = args.backend + model_id = args.model = LORA_PATH["base"] + tokenizer_id = args.model + + tokenizer = get_tokenizer(tokenizer_id) + + input_requests = sample_random_requests( + input_len=args.random_input_len, + output_len=args.random_output_len, + num_prompts=args.num_prompts, + range_ratio=args.random_range_ratio, + tokenizer=tokenizer, + dataset_path="", + ) + + return asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + disable_tqdm=False, + extra_request_body={}, + ) + ) + + +def set_ulimit(target_soft_limit=65535): + resource_type = resource.RLIMIT_NOFILE + current_soft, current_hard = resource.getrlimit(resource_type) + + if current_soft < target_soft_limit: + try: + resource.setrlimit(resource_type, (target_soft_limit, current_hard)) + except ValueError as e: + print(f"Fail to set RLIMIT_NOFILE: {e}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the online lora serving throughput.") + parser.add_argument( + "--backend", + type=str, + choices=list(ASYNC_REQUEST_FUNCS.keys()), + default="sglang", + help="Must specify a backend, depending on the LLM Inference Engine.", + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." + ) + parser.add_argument( + "--port", + type=int, + help="If not set, the default port is configured according to its default value for different LLM Inference Engines.", + ) + parser.add_argument( + "--num-prompts", + type=int, + default=50, + help="Number of prompts to process. Default is 1000.", + ) + parser.add_argument( + "--random-input-len", + type=int, + default=1024, + help="Number of input tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-output-len", + type=int, + default=128, + help="Number of output tokens per request, used only for random dataset.", + ) + parser.add_argument( + "--random-range-ratio", + type=float, + default=0.0, + help="Range of sampled ratio of input/output length, " + "used only for random dataset.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize the request arrival times. 
+if __name__ == "__main__":
+    parser = ArgumentParser(description="Benchmark online LoRA serving throughput.")
+    parser.add_argument(
+        "--backend",
+        type=str,
+        choices=list(ASYNC_REQUEST_FUNCS.keys()),
+        default="sglang",
+        help="Must specify a backend, depending on the LLM Inference Engine.",
+    )
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
+    )
+    parser.add_argument(
+        "--num-prompts",
+        type=int,
+        default=50,
+        help="Number of prompts to process. Default is 50.",
+    )
+    parser.add_argument(
+        "--random-input-len",
+        type=int,
+        default=1024,
+        help="Number of input tokens per request, used only for random dataset.",
+    )
+    parser.add_argument(
+        "--random-output-len",
+        type=int,
+        default=128,
+        help="Number of output tokens per request, used only for random dataset.",
+    )
+    parser.add_argument(
+        "--random-range-ratio",
+        type=float,
+        default=0.0,
+        help="Range of sampled ratio of input/output length, "
+        "used only for random dataset.",
+    )
+    parser.add_argument(
+        "--request-rate",
+        type=float,
+        default=float("inf"),
+        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
+        "Otherwise, we use a Poisson process to synthesize the request arrival times. Default is inf.",
+    )
+    parser.add_argument(
+        "--base-only",
+        action="store_true",
+    )
+    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+    args = parser.parse_args()
+    run_benchmark(args)
diff --git a/examples/runtime/lora.py b/examples/runtime/lora.py
new file mode 100644
index 000000000..183cfb484
--- /dev/null
+++ b/examples/runtime/lora.py
@@ -0,0 +1,37 @@
+# Launch the server:
+# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora /home/ying/test_lora_1 /home/ying/test_lora_2 lora3=/home/ying/test_lora_3 lora4=/home/ying/test_lora_4 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
+
+# Send requests:
+# lora_path[i] specifies the LoRA adapter used for text[i], so the two lists must have the same length.
+# Use None for a base-only prompt, e.g. "lora_path": [None, "/home/ying/test_lora"]
+import json
+
+import requests
+
+url = "http://127.0.0.1:30000"
+json_data = {
+    "text": [
+        "prompt 1",
+        "prompt 2",
+        "prompt 3",
+        "prompt 4",
+        "prompt 5",
+        "prompt 6",
+        "prompt 7",
+    ],
+    "sampling_params": {"max_new_tokens": 32},
+    "lora_path": [
+        "/home/ying/test_lora",
+        "/home/ying/test_lora_1",
+        "/home/ying/test_lora_2",
+        "lora3",
+        "lora4",
+        "/home/ying/test_lora",
+        "/home/ying/test_lora_1",
+    ],
+}
+response = requests.post(
+    url + "/generate",
+    json=json_data,
+)
+print(json.dumps(response.json()))
diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index 5e11280a4..a9ea232ed 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -96,10 +96,10 @@ class LoRAManager:
         # get configs and target modules
         self.configs = {}
         self.origin_target_modules = set()
-        for path in self.lora_paths:
-            self.configs[path] = LoRAConfig(path)
+        for name, path in self.lora_paths.items():
+            self.configs[name] = LoRAConfig(path)
             self.origin_target_modules = set(self.origin_target_modules) | set(
-                self.configs[path].target_modules
+                self.configs[name].target_modules
             )
         self.target_modules = set(
             [
@@ -114,11 +114,11 @@ class LoRAManager:
         # load all weights to cpu
         self.loras = []
         self.lora_id = {}
-        for path in self.lora_paths:
-            self.lora_id[path] = len(self.loras)
+        for name in self.lora_paths.keys():
+            self.lora_id[name] = len(self.loras)
             self.loras.append(
                 LoRAAdapter(
-                    path, self.configs[path], self.base_hf_config, self.load_config
+                    name, self.configs[name], self.base_hf_config, self.load_config
                 )
             )
             self.loras[-1].initialize_weights()
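
For reference, a minimal sketch (not part of the patch) of the name-keyed bookkeeping these hunks introduce; keying adapters by name lets several names share one weight path:

```python
# Hypothetical adapter registry: two names may point at the same weights.
lora_paths = {"lora0": "/home/ying/test_lora", "lora1": "/home/ying/test_lora"}
lora_id = {name: i for i, name in enumerate(lora_paths)}
print(lora_id)  # {'lora0': 0, 'lora1': 1}
```
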
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index a35c5b423..11769b57f 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -24,6 +24,17 @@ from typing import List, Optional, Union
 
 logger = logging.getLogger(__name__)
 
+class LoRAPathAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, {})
+        for lora_path in values:
+            if "=" in lora_path:
+                name, path = lora_path.split("=", 1)
+                getattr(namespace, self.dest)[name] = path
+            else:
+                getattr(namespace, self.dest)[lora_path] = lora_path
+
+
 @dataclasses.dataclass
 class ServerArgs:
     # Model and tokenizer
@@ -532,7 +543,8 @@ class ServerArgs:
             type=str,
             nargs="*",
             default=None,
-            help="The list of LoRA adapters.",
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. Each adapter can be given as a plain path or as {name}={path}; a plain path is registered under the path itself as its name.",
         )
         parser.add_argument(
             "--max-loras-per-batch",
diff --git a/scripts/playground/lora/test_lora.py b/scripts/playground/lora/test_lora.py
deleted file mode 100644
index 069020c42..000000000
--- a/scripts/playground/lora/test_lora.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import json
-
-import openai
-import requests
-
-import sglang as sgl
-
-lora_path = "/home/ying/test_lora"
-prompt_file = "/home/ying/test_prompt/dialogue_choice_prompts.json"
-server_url = "http://127.0.0.1:30000"
-
-client = openai.Client(base_url=server_url + "/v1", api_key="EMPTY")
-
-
-# @sgl.function
-# def generate(s, prompt):
-#     s += prompt
-#     s += sgl.gen("ans")
-
-# sgl.set_default_backend(sgl.RuntimeEndpoint(server_url))
-
-
-def generate(prompt, lora_path):
-    json_data = {
-        "text": prompt,
-        "sampling_params": {},
-        "return_logprob": False,
-        "logprob_start_len": None,
-        "top_logprobs_num": None,
-        "lora_path": lora_path,
-    }
-    response = requests.post(
-        server_url + "/generate",
-        json=json_data,
-    )
-    return json.dumps(response.json())
-
-
-with open(prompt_file, "r") as f:
-    samples = json.load(f)
-
-
-for sample in samples[:1]:
-    assert sample[0]["role"] == "user"
-    prompt = sample[0]["content"]
-    assert sample[1]["role"] == "assistant"
-    ref = sample[1]["content"]
-
-    state = generate(prompt, lora_path)
-    print("================================")
-    print(ref)
-    print("--------------------------------")
-    # print(state["ans"])
-    print(state)
-    print()
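
For reference, a minimal sketch of how `LoRAPathAction` parses the new `--lora-paths` syntax, assuming this patch is applied:

```python
# Exercises the LoRAPathAction added in server_args.py above.
import argparse

from sglang.srt.server_args import LoRAPathAction

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lora-paths", type=str, nargs="*", default=None, action=LoRAPathAction
)
args = parser.parse_args(
    ["--lora-paths", "lora3=/home/ying/test_lora_3", "/home/ying/test_lora"]
)
print(args.lora_paths)
# {'lora3': '/home/ying/test_lora_3', '/home/ying/test_lora': '/home/ying/test_lora'}
```
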