[Feature] Support LoRA path renaming and add LoRA serving benchmarks (#1433)
benchmark/lora/launch_server.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import argparse
import os

NUM_LORAS = 128
LORA_PATH = {
    "base": "mistralai/Mistral-7B-Instruct-v0.3",
    "lora": "/home/ying/test_lora",
}


def launch_server(args):
    base_path = LORA_PATH["base"]
    lora_path = LORA_PATH["lora"]
    max_loras_per_batch = 4

    if args.base_only:
        cmd = f"python -m sglang.launch_server --model {base_path} "
    else:
        cmd = f"python -m sglang.launch_server --model {base_path} --lora-paths "
        for i in range(NUM_LORAS):
            lora_name = f"lora{i}"
            cmd += f"{lora_name}={lora_path} "
    cmd += f"--disable-radix --disable-cuda-graph "
    cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
    cmd += f"--max-running-requests {args.max_running_requests}"
    print(cmd)
    os.system(cmd)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-loras",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--base-only",
        action="store_true",
    )
    parser.add_argument(
        "--max-loras-per-batch",
        type=int,
        default=8,
    )
    parser.add_argument(
        "--max-running-requests",
        type=int,
        default=8,
    )
    args = parser.parse_args()

    launch_server(args)
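Not part of the commit, but for illustration: the loop above registers NUM_LORAS logical adapter names, lora0 through lora127, that all resolve to the same adapter directory, so the --lora-paths fragment it composes looks like the output of this minimal sketch (placeholder path taken from LORA_PATH above):

# Illustrative sketch only; mirrors the name=path loop in launch_server() for three adapters.
lora_path = "/home/ying/test_lora"
fragment = " ".join(f"lora{i}={lora_path}" for i in range(3))
print(fragment)
# lora0=/home/ying/test_lora lora1=/home/ying/test_lora lora2=/home/ying/test_lora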
benchmark/lora/lora_bench.py (new file, 485 lines)
@@ -0,0 +1,485 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse
import asyncio
import json
import os
import random
import resource
import sys
import time
import traceback
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import aiohttp
import numpy as np
import requests
from launch_server import LORA_PATH, NUM_LORAS
from tqdm.asyncio import tqdm
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

from sglang.bench_serving import (
    AIOHTTP_TIMEOUT,
    SHAREGPT_URL,
    BenchmarkMetrics,
    RequestFuncInput,
    RequestFuncOutput,
    calculate_metrics,
    check_chat_template,
    get_model,
    get_request,
    get_tokenizer,
    parse_request_rate_range,
    remove_prefix,
    sample_random_requests,
)

global args


# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    # assert api_url.endswith(
    #     "completions"
    # ), "OpenAI Completions API URL must end with 'completions'."

    prompt = request_func_input.prompt

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        # payload = {
        #     "model": request_func_input.model,
        #     "prompt": prompt,
        #     "temperature": 0.0,
        #     "best_of": 1,
        #     "max_tokens": request_func_input.output_len,
        #     "stream": not args.disable_stream,
        #     "ignore_eos": not args.disable_ignore_eos,
        #     **request_func_input.extra_request_body,
        # }
        # headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
        if args.base_only:
            payload = {
                "text": prompt,
                "sampling_params": {"max_new_tokens": request_func_input.output_len},
            }
        else:
            payload = {
                "text": prompt,
                "sampling_params": {"max_new_tokens": request_func_input.output_len},
                "lora_path": f"lora{random.randint(0, NUM_LORAS - 1)}",
            }
        headers = {"Authorization": ""}

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["text"]:
                                # if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                # generated_text += data["choices"][0]["text"]
                                generated_text += data["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = request_func_input.output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_openai_completions,
}

async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
    extra_request_body: Dict[str, Any],
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        extra_request_body=extra_request_body,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            extra_request_body=extra_request_body,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Backend:", backend))
    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10}".format(
            "Total generated tokens (retokenized):", metrics.total_output_retokenized
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print(
        "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
    )
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    if (
        metrics.median_ttft_ms is not None
        and metrics.mean_itl_ms is not None
        and metrics.output_throughput is not None
    ):
        result = {
            "backend": args.backend,
            "request_rate": request_rate,
            "total_input_tokens": metrics.total_input,
            "total_output_tokens": metrics.total_output,
            "total_output_tokens_retokenized": metrics.total_output_retokenized,
            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
            "median_ttft_ms": metrics.median_ttft_ms,
            "median_itl_ms": metrics.median_itl_ms,
            "output_throughput": metrics.output_throughput,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            "duration": benchmark_duration,
            "completed": metrics.completed,
        }
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
        print("-" * 30)

    # Determine output file name
    if args.output_file:
        output_file_name = args.output_file
    else:
        now = datetime.now().strftime("%m%d")
        output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"

    # Append results to a JSONL file
    with open(output_file_name, "a") as file:
        file.write(json.dumps(result) + "\n")

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "total_output_tokens_retokenized": metrics.total_output_retokenized,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
    }
    return result

def run_benchmark(args_: argparse.Namespace):
    global args
    args = args_

    # Set global environments
    set_ulimit()
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Set url
    if args.port is None:
        args.port = {
            "sglang": 30000,
        }.get(args.backend, 30000)

    # api_url = (
    #     f"{args.base_url}/v1/completions"
    #     if args.base_url
    #     else f"http://{args.host}:{args.port}/v1/completions"
    # )
    api_url = (
        f"{args.base_url}/generate"
        if args.base_url
        else f"http://{args.host}:{args.port}/generate"
    )

    print(f"{args}\n")

    # Read dataset
    backend = args.backend
    model_id = args.model = LORA_PATH["base"]
    tokenizer_id = args.model

    tokenizer = get_tokenizer(tokenizer_id)

    input_requests = sample_random_requests(
        input_len=args.random_input_len,
        output_len=args.random_output_len,
        num_prompts=args.num_prompts,
        range_ratio=args.random_range_ratio,
        tokenizer=tokenizer,
        dataset_path="",
    )

    return asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=args.request_rate,
            disable_tqdm=False,
            extra_request_body={},
        )
    )


def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Fail to set RLIMIT_NOFILE: {e}")

if __name__ == "__main__":
    parser = ArgumentParser(description="Benchmark the online lora serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        default="sglang",
        help="Must specify a backend, depending on the LLM Inference Engine.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=50,
        help="Number of prompts to process. Default is 50.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help="Number of output tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument(
        "--base-only",
        action="store_true",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    args = parser.parse_args()
    run_benchmark(args)
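Not part of the commit: a minimal sketch of driving this benchmark programmatically rather than via the CLI. It assumes the server started by launch_server.py is already running on port 30000 and that the snippet sits next to lora_bench.py; the Namespace fields simply mirror the CLI defaults above.

# Hypothetical driver script, mirroring the argparse defaults of lora_bench.py.
import argparse

from lora_bench import run_benchmark

args = argparse.Namespace(
    backend="sglang",
    base_url=None,
    host="0.0.0.0",
    port=30000,
    num_prompts=50,
    random_input_len=1024,
    random_output_len=128,
    random_range_ratio=0.0,
    request_rate=float("inf"),
    base_only=False,
    output_file=None,
    seed=1,
)
result = run_benchmark(args)
print(result["request_throughput"], result["median_ttft_ms"])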
examples/runtime/lora.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# launch server
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora /home/ying/test_lora_1 /home/ying/test_lora_2 lora3=/home/ying/test_lora_3 lora4=/home/ying/test_lora_4 --disable-radix --disable-cuda-graph --max-loras-per-batch 4

# send requests
# lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length
# use None to specify a base-only prompt, e.g. "lora_path": [None, "/home/ying/test_lora"]
import json

import requests

url = "http://127.0.0.1:30000"
json_data = {
    "text": [
        "prompt 1",
        "prompt 2",
        "prompt 3",
        "prompt 4",
        "prompt 5",
        "prompt 6",
        "prompt 7",
    ],
    "sampling_params": {"max_new_tokens": 32},
    "lora_path": [
        "/home/ying/test_lora",
        "/home/ying/test_lora_1",
        "/home/ying/test_lora_2",
        "lora3",
        "lora4",
        "/home/ying/test_lora",
        "/home/ying/test_lora_1",
    ],
}
response = requests.post(
    url + "/generate",
    json=json_data,
)
print(json.dumps(response.json()))
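As the comments above note, None in lora_path selects the base model for that prompt, so one batch can mix base-only and LoRA prompts. A small follow-up sketch (not part of the example file; the adapter name "lora3" is taken from the launch command above):

import requests

url = "http://127.0.0.1:30000"
json_data = {
    "text": ["a base-only prompt", "a LoRA prompt"],
    "sampling_params": {"max_new_tokens": 32},
    # None -> base model; "lora3" -> the adapter registered as lora3=/home/ying/test_lora_3
    "lora_path": [None, "lora3"],
}
print(requests.post(url + "/generate", json=json_data).json())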
@@ -96,10 +96,10 @@ class LoRAManager:
         # get configs and target modules
         self.configs = {}
         self.origin_target_modules = set()
-        for path in self.lora_paths:
-            self.configs[path] = LoRAConfig(path)
+        for name, path in self.lora_paths.items():
+            self.configs[name] = LoRAConfig(path)
             self.origin_target_modules = set(self.origin_target_modules) | set(
-                self.configs[path].target_modules
+                self.configs[name].target_modules
             )
         self.target_modules = set(
             [
@@ -114,11 +114,11 @@ class LoRAManager:
         # load all weights to cpu
         self.loras = []
         self.lora_id = {}
-        for path in self.lora_paths:
-            self.lora_id[path] = len(self.loras)
+        for name in self.lora_paths.keys():
+            self.lora_id[name] = len(self.loras)
             self.loras.append(
                 LoRAAdapter(
-                    path, self.configs[path], self.base_hf_config, self.load_config
+                    name, self.configs[name], self.base_hf_config, self.load_config
                 )
             )
             self.loras[-1].initialize_weights()
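In short, LoRAManager now receives lora_paths as a name-to-path dict rather than a flat list, and configs and lora_id are keyed by adapter name. A rough illustration of the mapping (assumed values, not from the diff):

# Assumed example of the dict produced by --lora-paths and consumed by LoRAManager.
lora_paths = {
    "lora3": "/home/ying/test_lora_3",               # renamed adapter given as name=path
    "/home/ying/test_lora": "/home/ying/test_lora",  # bare path is keyed by the path itself
}
# Requests then pick an adapter by key, e.g. "lora_path": "lora3".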
@@ -24,6 +24,17 @@ from typing import List, Optional, Union
 logger = logging.getLogger(__name__)
 
 
+class LoRAPathAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, {})
+        for lora_path in values:
+            if "=" in lora_path:
+                name, path = lora_path.split("=", 1)
+                getattr(namespace, self.dest)[name] = path
+            else:
+                getattr(namespace, self.dest)[lora_path] = lora_path
+
+
 @dataclasses.dataclass
 class ServerArgs:
     # Model and tokenizer
@@ -532,7 +543,8 @@ class ServerArgs:
             type=str,
             nargs="*",
             default=None,
-            help="The list of LoRA adapters.",
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
         )
         parser.add_argument(
             "--max-loras-per-batch",
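A standalone sketch (not part of the diff) of how the new LoRAPathAction interprets --lora-paths values: a bare path is registered under the path itself, while name=path registers the adapter under the given name. The adapter paths below are hypothetical and only for illustration.

import argparse


class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, {})
        for lora_path in values:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                getattr(namespace, self.dest)[name] = path
            else:
                getattr(namespace, self.dest)[lora_path] = lora_path


parser = argparse.ArgumentParser()
parser.add_argument("--lora-paths", type=str, nargs="*", default=None, action=LoRAPathAction)
ns = parser.parse_args(["--lora-paths", "lora0=/adapters/a", "/adapters/b"])
print(ns.lora_paths)
# {'lora0': '/adapters/a', '/adapters/b': '/adapters/b'}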
@@ -1,55 +0,0 @@
import json

import openai
import requests

import sglang as sgl

lora_path = "/home/ying/test_lora"
prompt_file = "/home/ying/test_prompt/dialogue_choice_prompts.json"
server_url = "http://127.0.0.1:30000"

client = openai.Client(base_url=server_url + "/v1", api_key="EMPTY")


# @sgl.function
# def generate(s, prompt):
#     s += prompt
#     s += sgl.gen("ans")

# sgl.set_default_backend(sgl.RuntimeEndpoint(server_url))


def generate(prompt, lora_path):
    json_data = {
        "text": prompt,
        "sampling_params": {},
        "return_logprob": False,
        "logprob_start_len": None,
        "top_logprobs_num": None,
        "lora_path": lora_path,
    }
    response = requests.post(
        server_url + "/generate",
        json=json_data,
    )
    return json.dumps(response.json())


with open(prompt_file, "r") as f:
    samples = json.load(f)


for sample in samples[:1]:
    assert sample[0]["role"] == "user"
    prompt = sample[0]["content"]
    assert sample[1]["role"] == "assistant"
    ref = sample[1]["content"]

    state = generate(prompt, lora_path)
    print("================================")
    print(ref)
    print("--------------------------------")
    # print(state["ans"])
    print(state)
    print()