[Feature] Define backends and add Triton backend for Lora (#3161)
Co-authored-by: Ying Sheng <sqy1415@gmail.com>
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
NUM_LORAS = 8
|
||||
NUM_LORAS = 4
|
||||
LORA_PATH = {
|
||||
"base": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"lora": "/home/ying/test_lora",
|
||||
"base": "meta-llama/Llama-2-7b-hf",
|
||||
"lora": "winddude/wizardLM-LlaMA-LoRA-7B",
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,8 @@ def launch_server(args):
|
||||
cmd += f"{lora_name}={lora_path} "
|
||||
cmd += f"--disable-radix --disable-cuda-graph "
|
||||
cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
|
||||
cmd += f"--max-running-requests {args.max_running_requests}"
|
||||
cmd += f"--max-running-requests {args.max_running_requests} "
|
||||
cmd += f"--lora-backend {args.lora_backend}"
|
||||
print(cmd)
|
||||
os.system(cmd)
|
||||
|
||||
@@ -42,6 +43,11 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=8,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-backend",
|
||||
type=str,
|
||||
default="triton",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
launch_server(args)
|
||||
|
||||
@@ -183,6 +183,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
lora_name="dummy", # the lora_name argument will not be used
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
@@ -206,6 +207,7 @@ async def benchmark(
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
lora_name="dummy",
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
tasks.append(
|
||||
@@ -255,6 +257,9 @@ async def benchmark(
|
||||
"Output token throughput (tok/s):", metrics.output_throughput
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput)
|
||||
)
|
||||
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
||||
|
||||
Reference in New Issue
Block a user