[Minor] Many cleanups (#1357)
@@ -1,8 +1,3 @@
-## Download data
-```
-bash download_data.sh
-```
-
 ## Run benchmark
 
 ### Benchmark sglang
@@ -10,7 +10,7 @@ import numpy as np
 from tqdm import tqdm
 
 from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
-from sglang.utils import dump_state_text, read_jsonl
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
 INVALID = -9999999
 
@@ -41,24 +41,28 @@ def get_answer_value(answer_str):
 
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    call_generate = get_call_generate(args)
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
 
     states = [None] * len(labels)
 
-    # Select backend
-    call_generate = get_call_generate(args)
-
     # Run requests
     if args.backend != "lmql":
         # Use thread pool
@@ -113,11 +117,13 @@ def main(args):
     # Compute accuracy
     acc = np.mean(np.array(preds) == np.array(labels))
     invalid = np.mean(np.array(preds) == INVALID)
-    print(f"Latency: {latency:.3f}")
-    print(f"Invalid: {invalid:.3f}")
-    print(f"Accuracy: {acc:.3f}")
 
-    # Write results
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+
+    # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
 
     with open(args.result_file, "a") as fout:
@@ -138,7 +144,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=5)
+    parser.add_argument("--num-shots", type=int, default=5)
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_other_args_and_parse(parser)
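
Note: these GSM-8K benchmark scripts rename the CLI flag from `--num-shot` to `--num-shots`, and they now fetch the test set from the GSM-8K URL themselves rather than reading `args.data_path` (the `--data-path` argument is kept but the default flow no longer uses it). A hedged example of the updated invocation, assuming this hunk belongs to the gsm8k `bench_other.py` script and that the remaining flags come from `add_common_other_args_and_parse`:

```
python3 bench_other.py --num-shots 5 --num-questions 200
```
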
@@ -6,11 +6,12 @@ import time
 
 import numpy as np
 
+from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
-from sglang.utils import dump_state_text, read_jsonl
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
 INVALID = -9999999
 
@@ -41,15 +42,22 @@ def get_answer_value(answer_str):
 
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    set_default_backend(select_sglang_backend(args))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
@@ -72,15 +80,11 @@ def main(args):
     ########## SGL Program End ##########
     #####################################
 
-    # Select backend
-    backend = select_sglang_backend(args)
-
     # Run requests
     tic = time.time()
     states = few_shot_gsm8k.run_batch(
         arguments,
         temperature=0,
-        backend=backend,
         num_threads=args.parallel,
         progress_bar=True,
     )
@@ -96,11 +100,20 @@ def main(args):
     # Compute accuracy
     acc = np.mean(np.array(preds) == np.array(labels))
     invalid = np.mean(np.array(preds) == INVALID)
-    print(f"Latency: {latency:.3f}")
-    print(f"Invalid: {invalid:.3f}")
-    print(f"Accuracy: {acc:.3f}")
 
-    # Write results
+    # Compute speed
+    num_output_tokens = sum(
+        s.get_meta_info("answer")["completion_tokens"] for s in states
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")
+
+    # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
 
     with open(args.result_file, "a") as fout:
@@ -121,7 +134,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=5)
+    parser.add_argument("--num-shots", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_sglang_args_and_parse(parser)
@@ -1,2 +0,0 @@
-wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
-wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
@@ -1,8 +1,3 @@
-## Download data
-```
-wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl
-```
-
 ## Run benchmark
 
 ### Benchmark sglang
@@ -8,7 +8,7 @@ import numpy as np
 from tqdm import tqdm
 
 from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
-from sglang.utils import read_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 def get_one_example(lines, i, include_answer):
@@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k):
 
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    call_select = get_call_select(args)
+
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     choices = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         choices.append(lines[i]["endings"])
         labels.append(lines[i]["label"])
 
     preds = [None] * len(labels)
 
-    # Select backend
-    call_select = get_call_select(args)
-
     # Run requests
     if args.backend != "lmql":
         # Use thread pool
@@ -65,7 +69,6 @@ def main(args):
                 total=len(questions),
             )
         )
-
     else:
         # Use asyncio
         async def batched_call(batch_size):
@@ -108,7 +111,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=20)
+    parser.add_argument("--num-shots", type=int, default=20)
     parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_other_args_and_parse(parser)
@@ -4,11 +4,12 @@ import time
 
 import numpy as np
 
+from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
-from sglang.utils import read_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 def get_one_example(lines, i, include_answer):
@@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k):
 
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    set_default_backend(select_sglang_backend(args))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     choices = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         choices.append(lines[i]["endings"])
         labels.append(lines[i]["label"])
@@ -56,15 +64,11 @@ def main(args):
     ########## SGL Program End ##########
     #####################################
 
-    # Select backend
-    backend = select_sglang_backend(args)
-
     # Run requests
     tic = time.time()
     rets = few_shot_hellaswag.run_batch(
         arguments,
         temperature=0,
-        backend=backend,
         num_threads=args.parallel,
         progress_bar=True,
     )
@@ -95,7 +99,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=20)
+    parser.add_argument("--num-shots", type=int, default=20)
     parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_sglang_args_and_parse(parser)
@@ -7,6 +7,7 @@ python3 srt_example_llava_v.py
 
 import argparse
 import csv
+import json
 import os
 import time
 
@@ -223,7 +224,7 @@ if __name__ == "__main__":
         tokenizer_path=tokenizer_path,
         port=cur_port,
         additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
-        model_override_args=model_override_args,
+        json_model_override_args=json.dumps(model_override_args),
         tp_size=1,
     )
     sgl.set_default_backend(runtime)
@@ -298,34 +298,41 @@ class BenchmarkMetrics:
     median_e2e_latency_ms: float
 
 
-default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
 
 
-def download_sharegpt_dataset(path):
-    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
-
-    print(f"Downloading dataset from {url}")
-    try:
-        response = requests.get(url, stream=True)
-        response.raise_for_status()
-
-        total_size = int(response.headers.get("content-length", 0))
-        block_size = 8192
-
-        with open(path, "wb") as f, tqdm(
-            desc="Downloading",
-            total=total_size,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as progress_bar:
-            for data in response.iter_content(block_size):
-                size = f.write(data)
-                progress_bar.update(size)
-
-        print(f"Dataset downloaded and saved to {path}")
-    except requests.RequestException as e:
-        raise Exception(f"Failed to download dataset: {e}")
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
+
+    # Check if the cache file already exists
+    if os.path.exists(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
+
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
+
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
 
 
 def sample_sharegpt_requests(
@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
-        download_sharegpt_dataset(default_sharegpt_path)
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
     with open(dataset_path) as f:
@@ -412,15 +414,8 @@ def sample_random_requests(
     # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and not os.path.isfile(
-        default_sharegpt_path
-    ):
-        download_sharegpt_dataset(default_sharegpt_path)
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
     with open(dataset_path) as f:
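
The `download_and_cache_file` helper introduced above replaces the ShareGPT-specific downloader: it defaults the cache path to `/tmp/<basename of url>`, returns immediately on a cache hit, and otherwise streams the download with a tqdm progress bar. A minimal usage sketch, assuming the function and `SHAREGPT_URL` defined above are in scope:

```
path = download_and_cache_file(SHAREGPT_URL)   # first call downloads to /tmp
again = download_and_cache_file(SHAREGPT_URL)  # second call is a pure cache hit
assert path == again
```
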
@@ -9,10 +9,9 @@ from sglang.srt.utils import kill_child_process
 
 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])
-    model_override_args = server_args.json_model_override_args
 
     try:
-        launch_server(server_args, model_override_args=model_override_args)
+        launch_server(server_args)
     except Exception as e:
         raise e
     finally:
@@ -1,5 +1,6 @@
 """Launch the inference server for Llava-video model."""
 
+import json
 import sys
 
 from sglang.srt.server import launch_server, prepare_server_args
@@ -19,5 +20,6 @@ if __name__ == "__main__":
         model_override_args["model_max_length"] = 4096 * 2
     if "34b" in server_args.model_path.lower():
         model_override_args["image_token_index"] = 64002
+    server_args.json_model_override_args = json.dumps(model_override_args)
 
-    launch_server(server_args, model_override_args, None)
+    launch_server(server_args)
@@ -16,6 +16,7 @@ limitations under the License.
 """Cache for the compressed finite state machine."""
 
 from outlines.fsm.json_schema import build_regex_from_schema
+from transformers import AutoTokenizer
 
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
-        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
-        self.json_schema_mode = json_schema_mode
-
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
 
-        from importlib.metadata import version
-
-        if version("outlines") >= "0.0.35":
-            from transformers import AutoTokenizer
-
-            tokenizer_args_dict.setdefault("padding_side", "left")
-            tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_path, **tokenizer_args_dict
-            )
-            try:
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-            except AttributeError:
-                # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
-                origin_pad_token_id = tokenizer.pad_token_id
-
-                def fset(self, value):
-                    self._value = value
-
-                type(tokenizer).pad_token_id = property(
-                    fget=type(tokenizer).pad_token_id.fget, fset=fset
-                )
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-                self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token = (
-                    self.outlines_tokenizer.tokenizer.pad_token
-                )
-                self.outlines_tokenizer.vocabulary = (
-                    self.outlines_tokenizer.tokenizer.get_vocab()
-                )
-        else:
-            self.outlines_tokenizer = TransformerTokenizer(
-                tokenizer_path, **tokenizer_args_dict
-            )
+        tokenizer_args_dict.setdefault("padding_side", "left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        try:
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        except AttributeError:
+            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+            origin_pad_token_id = tokenizer.pad_token_id
+
+            def fset(self, value):
+                self._value = value
+
+            type(tokenizer).pad_token_id = property(
+                fget=type(tokenizer).pad_token_id.fget, fset=fset
+            )
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token = (
+                self.outlines_tokenizer.tokenizer.pad_token
+            )
+            self.outlines_tokenizer.vocabulary = (
+                self.outlines_tokenizer.tokenizer.get_vocab()
+            )
 
-    def init_value(self, value):
-        if self.json_schema_mode:
-            regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*")
-            return RegexGuide(regex, self.outlines_tokenizer), regex
+    def init_value(self, key):
+        key_type, key_string = key
+        if key_type == "json":
+            regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
+        elif key_type == "regex":
+            regex = key_string
         else:
-            return RegexGuide(value, self.outlines_tokenizer)
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+        return RegexGuide(regex, self.outlines_tokenizer), regex
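
With `json_schema_mode` removed, one `FSMCache` instance now serves both JSON-schema and plain-regex requests: keys become `(key_type, key_string)` tuples and `init_value` always returns a `(guide, computed_regex)` pair, which is what lets the jump-forward cache reuse the computed regex in both branches (see the tp_worker hunk below). A sketch of the new query convention, assuming a constructed cache and the `query` method inherited from `BaseToolCache`:

```
guide, regex = fsm_cache.query(("json", '{"type": "string"}'))
guide, regex = fsm_cache.query(("regex", r"(yes|no)"))
```
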
@@ -71,12 +71,10 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +112,6 @@ class ControllerMulti:
             self.server_args,
             self.port_args,
             pipe_controller_writer,
-            self.model_override_args,
             True,
             gpu_ids,
             dp_worker_id,
@@ -189,14 +186,13 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_override_args: dict,
 ):
     """Start a controller process."""
 
     configure_logger(server_args)
 
     try:
-        controller = ControllerMulti(server_args, port_args, model_override_args)
+        controller = ControllerMulti(server_args, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
@@ -40,7 +40,6 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +75,6 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-                model_override_args,
             )
 
         # Launch tp rank 0
@@ -85,7 +83,6 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
@@ -126,7 +123,6 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
@@ -149,7 +145,6 @@ def start_controller_process(
         controller = ControllerSingle(
             server_args,
             port_args,
-            model_override_args,
            gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,
@@ -18,6 +18,7 @@ limitations under the License.
 import asyncio
 import concurrent.futures
 import dataclasses
+import json
 import logging
 import multiprocessing as mp
 import os
@@ -77,7 +78,6 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict = None,
     ):
         self.server_args = server_args
 
@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
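
`TokenizerManager` no longer takes a `model_override_args` dict; the overrides now travel end-to-end as the JSON string `server_args.json_model_override_args` and are decoded with `json.loads` only at the point of use. Keeping the value a plain string means `ServerArgs` stays trivially serializable across the process boundaries used below. Schematically:

```
server_args.json_model_override_args = json.dumps({"model_max_length": 8192})
overrides = json.loads(server_args.json_model_override_args)  # inside TokenizerManager
```
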
@@ -15,13 +15,14 @@ limitations under the License.
 
 """A tensor parallel worker."""
 
+import json
 import logging
 import multiprocessing
 import os
 import pickle
 import time
 import warnings
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
 
 import torch
 import torch.distributed
@@ -66,6 +67,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 
+# Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
 
@@ -76,11 +78,10 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-        model_override_args: dict,
     ):
         suppress_other_loggers()
 
-        # Copy arguments
+        # Parse arguments
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
@@ -93,9 +94,8 @@ class ModelTpServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
@@ -136,7 +136,7 @@ class ModelTpServer:
             self.max_total_num_tokens - 1,
         )
 
-        # Sync random seed
+        # Sync random seed across TP workers
         server_args.random_seed = broadcast_recv_input(
             [server_args.random_seed],
             self.tp_rank,
@@ -144,7 +144,7 @@ class ModelTpServer:
         )[0]
         set_random_seed(server_args.random_seed)
 
-        # Print info
+        # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
@@ -181,7 +181,7 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
 
-        # Chunked prefill
+        # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         self.current_inflight_req = None
         self.is_mixed_chunk = (
@@ -197,16 +197,6 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
-            json_schema_mode=False,
-        )
-        self.json_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-            skip_tokenizer_init=server_args.skip_tokenizer_init,
-            json_schema_mode=True,
         )
         self.jump_forward_cache = JumpForwardCache()
 
@@ -227,11 +217,12 @@ class ModelTpServer:
         try:
             # Recv requests
            for recv_req in recv_reqs:
-                if isinstance(
-                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
-                ):
+                if isinstance(recv_req, TokenizedGenerateReqInput):
                     self.handle_generate_request(recv_req)
                     self.do_not_get_new_batch = False
+                elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+                    self.handle_embedding_request(recv_req)
+                    self.do_not_get_new_batch = False
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
                 elif isinstance(recv_req, AbortReq):
@@ -331,57 +322,56 @@ class ModelTpServer:
 
     def handle_generate_request(
         self,
-        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
+        recv_req: TokenizedGenerateReqInput,
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
         req.sampling_params = recv_req.sampling_params
-        if self.model_runner.is_generation:
-            req.pixel_values = recv_req.pixel_values
-            if req.pixel_values is not None:
-                # Use image hash as fake token_ids, which is then used
-                # for prefix matching
-                image_hash = hash(tuple(recv_req.image_hashes))
-                req.pad_value = [
-                    (image_hash) % self.model_config.vocab_size,
-                    (image_hash >> 16) % self.model_config.vocab_size,
-                    (image_hash >> 32) % self.model_config.vocab_size,
-                    (image_hash >> 64) % self.model_config.vocab_size,
-                ]
-                req.image_sizes = recv_req.image_sizes
-                (
-                    req.origin_input_ids,
-                    req.image_offsets,
-                ) = self.model_runner.model.pad_input_ids(
-                    req.origin_input_ids_unpadded,
-                    req.pad_value,
-                    req.pixel_values,
-                    req.image_sizes,
-                )
-                # Only when pixel values is not None we have modalities
-                req.modalities = recv_req.modalites
-            req.return_logprob = recv_req.return_logprob
-            req.logprob_start_len = recv_req.logprob_start_len
-            req.top_logprobs_num = recv_req.top_logprobs_num
-            req.stream = recv_req.stream
-
-            # Init regex fsm fron json
-            if req.sampling_params.json_schema is not None:
-                req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
-                    req.sampling_params.json_schema
-                )
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        computed_regex_string
-                    )
-
-            # Init regex fsm
-            elif req.sampling_params.regex is not None:
-                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        req.sampling_params.regex
-                    )
+        req.pixel_values = recv_req.pixel_values
+        if req.pixel_values is not None:
+            # Use image hash as fake token_ids, which is then used
+            # for prefix matching
+            image_hash = hash(tuple(recv_req.image_hashes))
+            req.pad_value = [
+                (image_hash) % self.model_config.vocab_size,
+                (image_hash >> 16) % self.model_config.vocab_size,
+                (image_hash >> 32) % self.model_config.vocab_size,
+                (image_hash >> 64) % self.model_config.vocab_size,
+            ]
+            req.image_sizes = recv_req.image_sizes
+            (
+                req.origin_input_ids,
+                req.image_offsets,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values,
+                req.image_sizes,
+            )
+            # Only when pixel values is not None we have modalities
+            req.modalities = recv_req.modalites
+        req.return_logprob = recv_req.return_logprob
+        req.logprob_start_len = recv_req.logprob_start_len
+        req.top_logprobs_num = recv_req.top_logprobs_num
+        req.stream = recv_req.stream
+
+        # Init regex FSM
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            if req.sampling_params.json_schema is not None:
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("json", req.sampling_params.json_schema)
+                )
+            elif req.sampling_params.regex is not None:
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("regex", req.sampling_params.regex)
+                )
+            if not self.disable_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
+                    computed_regex_string
+                )
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -390,16 +380,32 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        req.sampling_params.max_new_tokens = min(
+            (
+                req.sampling_params.max_new_tokens
+                if req.sampling_params.max_new_tokens is not None
+                else 1 << 30
+            ),
+            self.max_req_input_len - 1 - len(req.origin_input_ids),
+        )
+        self.waiting_queue.append(req)
 
-        if self.model_runner.is_generation:
-            req.sampling_params.max_new_tokens = min(
-                (
-                    req.sampling_params.max_new_tokens
-                    if req.sampling_params.max_new_tokens is not None
-                    else 1 << 30
-                ),
-                self.max_req_input_len - 1 - len(req.origin_input_ids),
-            )
-
-        self.waiting_queue.append(req)
+    def handle_embedding_request(
+        self,
+        recv_req: TokenizedEmbeddingReqInput,
+    ):
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+        req.tokenizer = self.tokenizer
+        req.sampling_params = recv_req.sampling_params
+
+        # Truncate prompts that are too long
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+
+        self.waiting_queue.append(req)
@@ -892,7 +898,6 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -903,7 +908,6 @@ def run_tp_server(
             tp_rank,
             server_args,
             nccl_port,
-            model_override_args,
         )
         tp_cpu_group = model_server.model_runner.tp_group.cpu_group
 
@@ -920,14 +924,13 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
+            args=(gpu_ids[i], i, server_args, nccl_port),
         )
         proc.start()
         procs.append(proc)
@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import json
 import logging
 import pkgutil
 from functools import lru_cache
@@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str):
 
 def launch_server(
     server_args: ServerArgs,
-    model_override_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     """Launch an HTTP server."""
@@ -317,7 +316,6 @@ def launch_server(
             tp_rank_range,
             server_args,
             ports[3],
-            model_override_args,
         )
 
         try:
@@ -328,7 +326,7 @@ def launch_server(
             return
 
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
@@ -341,7 +339,7 @@ def launch_server(
 
     proc_controller = mp.Process(
         target=start_controller_process,
-        args=(server_args, port_args, pipe_controller_writer, model_override_args),
+        args=(server_args, port_args, pipe_controller_writer),
     )
     proc_controller.start()
 
@@ -501,7 +499,6 @@ class Runtime:
     def __init__(
         self,
         log_level: str = "error",
-        model_override_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
@@ -525,7 +522,7 @@ class Runtime:
 
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, model_override_args, pipe_writer),
+            args=(self.server_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
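
Because `launch_server` and `Runtime.__init__` drop their `model_override_args` parameters, embedded callers pass overrides through the server args instead, exactly as the llava-video example above does with `json.dumps`. A hedged sketch of the embedded-runtime path (the model path is a placeholder):

```
import json
import sglang as sgl

runtime = sgl.Runtime(
    model_path="path/to/model",  # placeholder
    json_model_override_args=json.dumps({"model_max_length": 8192}),
)
```
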
@@ -76,6 +76,14 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
+    # Model override args in JSON
+    json_model_override_args: str = "{}"
+
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -91,14 +99,6 @@ class ServerArgs:
     enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    # Distributed args
-    nccl_init_addr: Optional[str] = None
-    nnodes: int = 1
-    node_rank: Optional[int] = None
-
-    # Model override args in JSON
-    json_model_override_args: Optional[dict] = None
-
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -385,6 +385,14 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+            default=ServerArgs.json_model_override_args,
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
@@ -459,22 +467,10 @@ class ServerArgs:
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
 
-        # Model override args
-        parser.add_argument(
-            "--json-model-override-args",
-            type=str,
-            help="A dictionary in JSON string format used to override default model configurations.",
-        )
-
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
         args.dp_size = args.data_parallel_size
-        args.json_model_override_args = (
-            json.loads(args.json_model_override_args)
-            if args.json_model_override_args
-            else None
-        )
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
@@ -498,7 +494,7 @@ class ServerArgs:
         self.disable_flashinfer = False
 
 
-def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
+def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
     Prepare the server arguments from the command line arguments.
 
@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
     """
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
-    raw_args = parser.parse_args(args)
+    raw_args = parser.parse_args(argv)
     server_args = ServerArgs.from_cli_args(raw_args)
     return server_args
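
On the command line, `--json-model-override-args` now has a real default (`"{}"`), stays a string through `from_cli_args`, and is grouped with the distributed args. An example invocation, assuming the standard `sglang.launch_server` entry point and a placeholder model path:

```
python3 -m sglang.launch_server --model-path path/to/model \
    --json-model-override-args '{"model_max_length": 8192}'
```
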
|
|||||||
132
python/sglang/test/few_shot_gsm8k.py
Normal file
132
python/sglang/test/few_shot_gsm8k.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""
|
||||||
|
Run few-shot GSM-8K evaluation.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 -m sglang.test.few_shot_gsm8k --num-questions 200
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sglang.api import set_default_backend
|
||||||
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
|
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
|
def get_one_example(lines, i, include_answer):
|
||||||
|
ret = "Question: " + lines[i]["question"] + "\nAnswer:"
|
||||||
|
if include_answer:
|
||||||
|
ret += " " + lines[i]["answer"]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_few_shot_examples(lines, k):
|
||||||
|
ret = ""
|
||||||
|
for i in range(k):
|
||||||
|
ret += get_one_example(lines, i, True) + "\n\n"
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_answer_value(answer_str):
|
||||||
|
answer_str = answer_str.replace(",", "")
|
||||||
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
|
if len(numbers) < 1:
|
||||||
|
return INVALID
|
||||||
|
try:
|
||||||
|
return ast.literal_eval(numbers[-1])
|
||||||
|
except SyntaxError:
|
||||||
|
return INVALID
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# Select backend
|
||||||
|
set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
|
||||||
|
|
||||||
|
# Read data
|
||||||
|
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||||
|
filename = download_and_cache_file(url)
|
||||||
|
lines = list(read_jsonl(filename))
|
||||||
|
|
||||||
|
# Construct prompts
|
||||||
|
num_questions = args.num_questions
|
||||||
|
num_shots = args.num_shots
|
||||||
|
few_shot_examples = get_few_shot_examples(lines, num_shots)
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
labels = []
|
||||||
|
for i in range(len(lines[:num_questions])):
|
||||||
|
questions.append(get_one_example(lines, i, False))
|
||||||
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
|
assert all(l != INVALID for l in labels)
|
||||||
|
arguments = [{"question": q} for q in questions]
|
||||||
|
|
||||||
|
#####################################
|
||||||
|
######### SGL Program Begin #########
|
||||||
|
#####################################
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def few_shot_gsm8k(s, question):
|
||||||
|
s += few_shot_examples + question
|
||||||
|
s += sgl.gen(
|
||||||
|
"answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
|
||||||
|
)
|
||||||
|
|
||||||
|
#####################################
|
||||||
|
########## SGL Program End ##########
|
||||||
|
#####################################
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
states = few_shot_gsm8k.run_batch(
|
||||||
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|

    preds = []
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]["answer"]))

    # print(f"{preds=}")
    # print(f"{labels=}")

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)

    # Compute speed
    num_output_tokens = sum(
        s.get_meta_info("answer")["completion_tokens"] for s in states
    )
    output_throughput = num_output_tokens / latency

    # Print results
    print(f"Accuracy: {acc:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Latency: {latency:.3f} s")
    print(f"Output throughput: {output_throughput:.3f} token/s")

    # Dump results
    dump_state_text("tmp_output_gsm8k.txt", states)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shots", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    parser.add_argument("--parallel", type=int, default=128)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    main(args)
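A worked example of the throughput figure printed above, with hypothetical numbers (not measured results): a batch that emits 20,000 completion tokens in 25 s of wall-clock time reports 800 token/s.
```
# Hypothetical values, shown only to illustrate the computation
num_output_tokens = 20_000  # sum of completion_tokens over all states
latency = 25.0              # wall-clock seconds for run_batch()
print(f"Output throughput: {num_output_tokens / latency:.3f} token/s")  # 800.000
```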
@@ -7,7 +7,7 @@ import time
 import numpy as np

 import sglang as sgl
-from sglang.utils import fetch_and_cache_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl


 def test_few_shot_qa():
@@ -456,10 +456,6 @@ def test_chat_completion_speculative():
 def test_hellaswag_select():
     """Benchmark the accuracy of sgl.select on the HellaSwag dataset."""

-    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
-    lines = fetch_and_cache_jsonl(url)
-
-    # Construct prompts
     def get_one_example(lines, i, include_answer):
         ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
         if include_answer:
@@ -472,6 +468,12 @@ def test_hellaswag_select():
             ret += get_one_example(lines, i, True) + "\n\n"
         return ret

+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
     num_questions = 200
     num_shots = 20
     few_shot_examples = get_few_shot_examples(lines, num_shots)
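Because the new read_jsonl (see the utils hunk below) is a generator, call sites that slice or index must materialize it with list(...) first, which is why the hunks above wrap it. A self-contained sketch of that behavior; read_jsonl_like is a stand-in that mirrors the new logic over an in-memory buffer instead of a file path:
```
import io
import json


def read_jsonl_like(fin):
    """Same logic as the new read_jsonl(), over any file-like object."""
    for line in fin:
        if line.startswith("#"):
            continue
        yield json.loads(line)


buf = io.StringIO('{"q": "1+1"}\n# comment line, skipped\n{"q": "2+2"}\n')
gen = read_jsonl_like(buf)
lines = list(gen)        # materialize once so lines[i] and lines[:n] work
assert len(lines) == 2 and lines[1]["q"] == "2+2"
assert list(gen) == []   # generators are single-pass; a second list() is empty
```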
@@ -12,7 +12,7 @@ import urllib.request
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
-from typing import Union
+from typing import Optional, Union

 import numpy as np
 import requests
@@ -38,13 +38,11 @@ def is_same_type(values: list):

 def read_jsonl(filename: str):
     """Read a JSONL file."""
-    rets = []
     with open(filename) as fin:
         for line in fin:
             if line.startswith("#"):
                 continue
-            rets.append(json.loads(line))
-    return rets
+            yield json.loads(line)


 def dump_state_text(filename: str, states: list, mode: str = "w"):
@@ -264,38 +262,35 @@ class LazyImport:
         return module(*args, **kwargs)


-def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
-    """Read and cache a jsonl file from a url."""
-    # Check if the cache file already exists
-    if os.path.exists(cache_file):
-        print("Loading data from cache...")
-        with open(cache_file, "r") as f:
-            data = [json.loads(line) for line in f]
-    else:
-        print("Downloading data from URL...")
-        # Stream the response to show the progress bar
-        response = requests.get(url, stream=True)
-        response.raise_for_status()  # Check for request errors
-
-        # Total size of the file in bytes
-        total_size = int(response.headers.get("content-length", 0))
-        chunk_size = 1024  # Download in chunks of 1KB
-
-        # Use tqdm to display the progress bar
-        with open(cache_file, "wb") as f, tqdm(
-            desc=cache_file,
-            total=total_size,
-            unit="B",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as bar:
-            for chunk in response.iter_content(chunk_size=chunk_size):
-                f.write(chunk)
-                bar.update(len(chunk))
-
-    # Convert the data to a list of dictionaries
-    with open(cache_file, "r") as f:
-        data = [json.loads(line) for line in f]
-
-    return data
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
+
+    # Check if the cache file already exists
+    if os.path.exists(filename):
+        return filename
+
+    print(f"Downloading from {url} to {filename}")
+
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
+
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
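A minimal usage sketch for the renamed helper, based only on what the hunk above shows: the cache path defaults to /tmp/<url basename>, and a cache hit returns the path without re-downloading.
```
from sglang.utils import download_and_cache_file, read_jsonl

url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)  # first call downloads to /tmp/test.jsonl
lines = list(read_jsonl(filename))       # later calls return the cached path directly
```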
@@ -42,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.63, f"{metrics}"
+        assert metrics["score"] >= 0.62, f"{metrics}"

     def test_human_eval(self):
         args = SimpleNamespace(
@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
         )

         metrics = run_eval(args)
-        assert metrics["score"] >= 0.63, f"{metrics}"
+        assert metrics["score"] >= 0.62, f"{metrics}"


 if __name__ == "__main__":
@@ -1,3 +1,4 @@
+import json
 import unittest

 from sglang.srt.server_args import prepare_server_args
@@ -15,7 +16,7 @@ class TestPrepareServerArgs(unittest.TestCase):
         )
         self.assertEqual(server_args.model_path, "model_path")
         self.assertEqual(
-            server_args.json_model_override_args,
+            json.loads(server_args.json_model_override_args),
            {"rope_scaling": {"factor": 2.0, "type": "linear"}},
         )
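The assertion now parses the stored value first because json_model_override_args is kept as a JSON string rather than a dict. In miniature, with standalone values mirroring the hunk above:
```
import json

json_model_override_args = '{"rope_scaling": {"factor": 2.0, "type": "linear"}}'
assert json.loads(json_model_override_args) == {
    "rope_scaling": {"factor": 2.0, "type": "linear"}
}
```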