[Minor] Many cleanup (#1357)

This commit is contained in:
Lianmin Zheng
2024-09-09 04:14:11 -07:00
committed by GitHub
parent c9b75917d5
commit e4d68afcf0
24 changed files with 416 additions and 296 deletions

View File

@@ -1,8 +1,3 @@
## Download data
```
bash download_data.sh
```
## Run benchmark ## Run benchmark
### Benchmark sglang ### Benchmark sglang

View File

@@ -10,7 +10,7 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999 INVALID = -9999999
@@ -41,24 +41,28 @@ def get_answer_value(answer_str):
def main(args): def main(args):
lines = read_jsonl(args.data_path) # Select backend
call_generate = get_call_generate(args)
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts # Construct prompts
k = args.num_shot num_questions = args.num_questions
few_shot_examples = get_few_shot_examples(lines, k) num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = [] questions = []
labels = [] labels = []
for i in range(len(lines[: args.num_questions])): for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False)) questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"])) labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels) assert all(l != INVALID for l in labels)
states = [None] * len(labels) states = [None] * len(labels)
# Select backend
call_generate = get_call_generate(args)
# Run requests # Run requests
if args.backend != "lmql": if args.backend != "lmql":
# Use thread pool # Use thread pool
@@ -113,11 +117,13 @@ def main(args):
# Compute accuracy # Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels)) acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID) invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results # Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states) dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout: with open(args.result_file, "a") as fout:
@@ -138,7 +144,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=5) parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl") parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200) parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser) args = add_common_other_args_and_parse(parser)

View File

@@ -6,11 +6,12 @@ import time
import numpy as np import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_sglang_args_and_parse, add_common_sglang_args_and_parse,
select_sglang_backend, select_sglang_backend,
) )
from sglang.utils import dump_state_text, read_jsonl from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999 INVALID = -9999999
@@ -41,15 +42,22 @@ def get_answer_value(answer_str):
def main(args): def main(args):
lines = read_jsonl(args.data_path) # Select backend
set_default_backend(select_sglang_backend(args))
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts # Construct prompts
k = args.num_shot num_questions = args.num_questions
few_shot_examples = get_few_shot_examples(lines, k) num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = [] questions = []
labels = [] labels = []
for i in range(len(lines[: args.num_questions])): for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False)) questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"])) labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels) assert all(l != INVALID for l in labels)
@@ -72,15 +80,11 @@ def main(args):
########## SGL Program End ########## ########## SGL Program End ##########
##################################### #####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests # Run requests
tic = time.time() tic = time.time()
states = few_shot_gsm8k.run_batch( states = few_shot_gsm8k.run_batch(
arguments, arguments,
temperature=0, temperature=0,
backend=backend,
num_threads=args.parallel, num_threads=args.parallel,
progress_bar=True, progress_bar=True,
) )
@@ -96,11 +100,20 @@ def main(args):
# Compute accuracy # Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels)) acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID) invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results # Compute speed
num_output_tokens = sum(
s.get_meta_info("answer")["completion_tokens"] for s in states
)
output_throughput = num_output_tokens / latency
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
print(f"Output throughput: {output_throughput:.3f} token/s")
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states) dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout: with open(args.result_file, "a") as fout:
@@ -121,7 +134,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=5) parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl") parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200) parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser) args = add_common_sglang_args_and_parse(parser)

View File

@@ -1,2 +0,0 @@
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl

View File

@@ -1,8 +1,3 @@
## Download data
```
wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl
```
## Run benchmark ## Run benchmark
### Benchmark sglang ### Benchmark sglang

View File

@@ -8,7 +8,7 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
from sglang.utils import read_jsonl from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer): def get_one_example(lines, i, include_answer):
@@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k):
def main(args): def main(args):
lines = read_jsonl(args.data_path) # Select backend
call_select = get_call_select(args)
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts # Construct prompts
k = args.num_shot num_questions = args.num_questions
few_shot_examples = get_few_shot_examples(lines, k) num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = [] questions = []
choices = [] choices = []
labels = [] labels = []
for i in range(len(lines[: args.num_questions])): for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False)) questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"]) choices.append(lines[i]["endings"])
labels.append(lines[i]["label"]) labels.append(lines[i]["label"])
preds = [None] * len(labels) preds = [None] * len(labels)
# Select backend
call_select = get_call_select(args)
# Run requests # Run requests
if args.backend != "lmql": if args.backend != "lmql":
# Use thread pool # Use thread pool
@@ -65,7 +69,6 @@ def main(args):
total=len(questions), total=len(questions),
) )
) )
else: else:
# Use asyncio # Use asyncio
async def batched_call(batch_size): async def batched_call(batch_size):
@@ -108,7 +111,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20) parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl") parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200) parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser) args = add_common_other_args_and_parse(parser)

View File

@@ -4,11 +4,12 @@ import time
import numpy as np import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import ( from sglang.test.test_utils import (
add_common_sglang_args_and_parse, add_common_sglang_args_and_parse,
select_sglang_backend, select_sglang_backend,
) )
from sglang.utils import read_jsonl from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer): def get_one_example(lines, i, include_answer):
@@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k):
def main(args): def main(args):
lines = read_jsonl(args.data_path) # Select backend
set_default_backend(select_sglang_backend(args))
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts # Construct prompts
k = args.num_shot num_questions = args.num_questions
few_shot_examples = get_few_shot_examples(lines, k) num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = [] questions = []
choices = [] choices = []
labels = [] labels = []
for i in range(len(lines[: args.num_questions])): for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False)) questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"]) choices.append(lines[i]["endings"])
labels.append(lines[i]["label"]) labels.append(lines[i]["label"])
@@ -56,15 +64,11 @@ def main(args):
########## SGL Program End ########## ########## SGL Program End ##########
##################################### #####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests # Run requests
tic = time.time() tic = time.time()
rets = few_shot_hellaswag.run_batch( rets = few_shot_hellaswag.run_batch(
arguments, arguments,
temperature=0, temperature=0,
backend=backend,
num_threads=args.parallel, num_threads=args.parallel,
progress_bar=True, progress_bar=True,
) )
@@ -95,7 +99,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20) parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl") parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200) parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser) args = add_common_sglang_args_and_parse(parser)

View File

@@ -7,6 +7,7 @@ python3 srt_example_llava_v.py
import argparse import argparse
import csv import csv
import json
import os import os
import time import time
@@ -223,7 +224,7 @@ if __name__ == "__main__":
tokenizer_path=tokenizer_path, tokenizer_path=tokenizer_path,
port=cur_port, port=cur_port,
additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4], additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
model_override_args=model_override_args, json_model_override_args=json.dumps(model_override_args),
tp_size=1, tp_size=1,
) )
sgl.set_default_backend(runtime) sgl.set_default_backend(runtime)

View File

@@ -298,34 +298,41 @@ class BenchmarkMetrics:
median_e2e_latency_ms: float median_e2e_latency_ms: float
default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
def download_sharegpt_dataset(path): def download_and_cache_file(url: str, filename: Optional[str] = None):
url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" """Read and cache a file from a url."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])
print(f"Downloading dataset from {url}") # Check if the cache file already exists
try: if os.path.exists(filename):
response = requests.get(url, stream=True) return filename
response.raise_for_status()
total_size = int(response.headers.get("content-length", 0)) print(f"Downloading from {url} to {filename}")
block_size = 8192
with open(path, "wb") as f, tqdm( # Stream the response to show the progress bar
desc="Downloading", response = requests.get(url, stream=True)
total=total_size, response.raise_for_status() # Check for request errors
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as progress_bar:
for data in response.iter_content(block_size):
size = f.write(data)
progress_bar.update(size)
print(f"Dataset downloaded and saved to {path}") # Total size of the file in bytes
except requests.RequestException as e: total_size = int(response.headers.get("content-length", 0))
raise Exception(f"Failed to download dataset: {e}") chunk_size = 1024 # Download in chunks of 1KB
# Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
return filename
def sample_sharegpt_requests( def sample_sharegpt_requests(
@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
raise ValueError("output_len too small") raise ValueError("output_len too small")
# Download sharegpt if necessary # Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): if not os.path.isfile(dataset_path):
download_sharegpt_dataset(default_sharegpt_path) dataset_path = download_and_cache_file(SHAREGPT_URL)
dataset_path = default_sharegpt_path
else:
dataset_path = (
dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
)
# Load the dataset. # Load the dataset.
with open(dataset_path) as f: with open(dataset_path) as f:
@@ -412,15 +414,8 @@ def sample_random_requests(
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
# Download sharegpt if necessary # Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile( if not os.path.isfile(dataset_path):
default_sharegpt_path dataset_path = download_and_cache_file(SHAREGPT_URL)
):
download_sharegpt_dataset(default_sharegpt_path)
dataset_path = default_sharegpt_path
else:
dataset_path = (
dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
)
# Load the dataset. # Load the dataset.
with open(dataset_path) as f: with open(dataset_path) as f:

View File

@@ -9,10 +9,9 @@ from sglang.srt.utils import kill_child_process
if __name__ == "__main__": if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:]) server_args = prepare_server_args(sys.argv[1:])
model_override_args = server_args.json_model_override_args
try: try:
launch_server(server_args, model_override_args=model_override_args) launch_server(server_args)
except Exception as e: except Exception as e:
raise e raise e
finally: finally:

View File

@@ -1,5 +1,6 @@
"""Launch the inference server for Llava-video model.""" """Launch the inference server for Llava-video model."""
import json
import sys import sys
from sglang.srt.server import launch_server, prepare_server_args from sglang.srt.server import launch_server, prepare_server_args
@@ -19,5 +20,6 @@ if __name__ == "__main__":
model_override_args["model_max_length"] = 4096 * 2 model_override_args["model_max_length"] = 4096 * 2
if "34b" in server_args.model_path.lower(): if "34b" in server_args.model_path.lower():
model_override_args["image_token_index"] = 64002 model_override_args["image_token_index"] = 64002
server_args.json_model_override_args = json.dumps(model_override_args)
launch_server(server_args, model_override_args, None) launch_server(server_args)

View File

@@ -16,6 +16,7 @@ limitations under the License.
"""Cache for the compressed finite state machine.""" """Cache for the compressed finite state machine."""
from outlines.fsm.json_schema import build_regex_from_schema from outlines.fsm.json_schema import build_regex_from_schema
from transformers import AutoTokenizer
from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_tool_cache import BaseToolCache from sglang.srt.constrained.base_tool_cache import BaseToolCache
@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
tokenizer_args_dict, tokenizer_args_dict,
enable=True, enable=True,
skip_tokenizer_init=False, skip_tokenizer_init=False,
json_schema_mode=False,
): ):
super().__init__(enable=enable) super().__init__(enable=enable)
self.json_schema_mode = json_schema_mode
if ( if (
skip_tokenizer_init skip_tokenizer_init
or tokenizer_path.endswith(".json") or tokenizer_path.endswith(".json")
@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
# Do not support TiktokenTokenizer or SentencePieceTokenizer # Do not support TiktokenTokenizer or SentencePieceTokenizer
return return
from importlib.metadata import version tokenizer_args_dict.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
try:
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
except AttributeError:
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
origin_pad_token_id = tokenizer.pad_token_id
if version("outlines") >= "0.0.35": def fset(self, value):
from transformers import AutoTokenizer self._value = value
tokenizer_args_dict.setdefault("padding_side", "left") type(tokenizer).pad_token_id = property(
tokenizer = AutoTokenizer.from_pretrained( fget=type(tokenizer).pad_token_id.fget, fset=fset
tokenizer_path, **tokenizer_args_dict
) )
try: self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer = TransformerTokenizer(tokenizer) self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
except AttributeError: self.outlines_tokenizer.pad_token_id = origin_pad_token_id
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) self.outlines_tokenizer.pad_token = (
origin_pad_token_id = tokenizer.pad_token_id self.outlines_tokenizer.tokenizer.pad_token
)
def fset(self, value): self.outlines_tokenizer.vocabulary = (
self._value = value self.outlines_tokenizer.tokenizer.get_vocab()
type(tokenizer).pad_token_id = property(
fget=type(tokenizer).pad_token_id.fget, fset=fset
)
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token = (
self.outlines_tokenizer.tokenizer.pad_token
)
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
else:
self.outlines_tokenizer = TransformerTokenizer(
tokenizer_path, **tokenizer_args_dict
) )
def init_value(self, value): def init_value(self, key):
if self.json_schema_mode: key_type, key_string = key
regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*") if key_type == "json":
return RegexGuide(regex, self.outlines_tokenizer), regex regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
elif key_type == "regex":
regex = key_string
else: else:
return RegexGuide(value, self.outlines_tokenizer) raise ValueError(f"Invalid key_type: {key_type}")
return RegexGuide(regex, self.outlines_tokenizer), regex

View File

@@ -71,12 +71,10 @@ class ControllerMulti:
self, self,
server_args: ServerArgs, server_args: ServerArgs,
port_args: PortArgs, port_args: PortArgs,
model_override_args,
): ):
# Parse args # Parse args
self.server_args = server_args self.server_args = server_args
self.port_args = port_args self.port_args = port_args
self.model_override_args = model_override_args
self.load_balance_method = LoadBalanceMethod.from_str( self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method server_args.load_balance_method
) )
@@ -114,7 +112,6 @@ class ControllerMulti:
self.server_args, self.server_args,
self.port_args, self.port_args,
pipe_controller_writer, pipe_controller_writer,
self.model_override_args,
True, True,
gpu_ids, gpu_ids,
dp_worker_id, dp_worker_id,
@@ -189,14 +186,13 @@ def start_controller_process(
server_args: ServerArgs, server_args: ServerArgs,
port_args: PortArgs, port_args: PortArgs,
pipe_writer, pipe_writer,
model_override_args: dict,
): ):
"""Start a controller process.""" """Start a controller process."""
configure_logger(server_args) configure_logger(server_args)
try: try:
controller = ControllerMulti(server_args, port_args, model_override_args) controller = ControllerMulti(server_args, port_args)
except Exception: except Exception:
pipe_writer.send(get_exception_traceback()) pipe_writer.send(get_exception_traceback())
raise raise

View File

@@ -40,7 +40,6 @@ class ControllerSingle:
self, self,
server_args: ServerArgs, server_args: ServerArgs,
port_args: PortArgs, port_args: PortArgs,
model_override_args: dict,
gpu_ids: List[int], gpu_ids: List[int],
is_data_parallel_worker: bool, is_data_parallel_worker: bool,
dp_worker_id: int, dp_worker_id: int,
@@ -76,7 +75,6 @@ class ControllerSingle:
tp_rank_range, tp_rank_range,
server_args, server_args,
port_args.nccl_ports[dp_worker_id], port_args.nccl_ports[dp_worker_id],
model_override_args,
) )
# Launch tp rank 0 # Launch tp rank 0
@@ -85,7 +83,6 @@ class ControllerSingle:
0, 0,
server_args, server_args,
port_args.nccl_ports[dp_worker_id], port_args.nccl_ports[dp_worker_id],
model_override_args,
) )
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
@@ -126,7 +123,6 @@ def start_controller_process(
server_args: ServerArgs, server_args: ServerArgs,
port_args: PortArgs, port_args: PortArgs,
pipe_writer: multiprocessing.connection.Connection, pipe_writer: multiprocessing.connection.Connection,
model_override_args: dict,
is_data_parallel_worker: bool = False, is_data_parallel_worker: bool = False,
gpu_ids: List[int] = None, gpu_ids: List[int] = None,
dp_worker_id: int = None, dp_worker_id: int = None,
@@ -149,7 +145,6 @@ def start_controller_process(
controller = ControllerSingle( controller = ControllerSingle(
server_args, server_args,
port_args, port_args,
model_override_args,
gpu_ids, gpu_ids,
is_data_parallel_worker, is_data_parallel_worker,
dp_worker_id, dp_worker_id,

View File

@@ -18,6 +18,7 @@ limitations under the License.
import asyncio import asyncio
import concurrent.futures import concurrent.futures
import dataclasses import dataclasses
import json
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
@@ -77,7 +78,6 @@ class TokenizerManager:
self, self,
server_args: ServerArgs, server_args: ServerArgs,
port_args: PortArgs, port_args: PortArgs,
model_override_args: dict = None,
): ):
self.server_args = server_args self.server_args = server_args
@@ -95,7 +95,7 @@ class TokenizerManager:
self.hf_config = get_config( self.hf_config = get_config(
self.model_path, self.model_path,
trust_remote_code=server_args.trust_remote_code, trust_remote_code=server_args.trust_remote_code,
model_override_args=model_override_args, model_override_args=json.loads(server_args.json_model_override_args),
) )
self.is_generation = is_generation_model( self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding self.hf_config.architectures, self.server_args.is_embedding

View File

@@ -15,13 +15,14 @@ limitations under the License.
"""A tensor parallel worker.""" """A tensor parallel worker."""
import json
import logging import logging
import multiprocessing import multiprocessing
import os import os
import pickle import pickle
import time import time
import warnings import warnings
from typing import Any, List, Optional, Union from typing import Any, List, Optional
import torch import torch
import torch.distributed import torch.distributed
@@ -66,6 +67,7 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true" crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
@@ -76,11 +78,10 @@ class ModelTpServer:
tp_rank: int, tp_rank: int,
server_args: ServerArgs, server_args: ServerArgs,
nccl_port: int, nccl_port: int,
model_override_args: dict,
): ):
suppress_other_loggers() suppress_other_loggers()
# Copy arguments # Parse arguments
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.tp_size = server_args.tp_size self.tp_size = server_args.tp_size
@@ -93,9 +94,8 @@ class ModelTpServer:
server_args.model_path, server_args.model_path,
server_args.trust_remote_code, server_args.trust_remote_code,
context_length=server_args.context_length, context_length=server_args.context_length,
model_override_args=model_override_args, model_override_args=json.loads(server_args.json_model_override_args),
) )
self.model_runner = ModelRunner( self.model_runner = ModelRunner(
model_config=self.model_config, model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static, mem_fraction_static=server_args.mem_fraction_static,
@@ -136,7 +136,7 @@ class ModelTpServer:
self.max_total_num_tokens - 1, self.max_total_num_tokens - 1,
) )
# Sync random seed # Sync random seed across TP workers
server_args.random_seed = broadcast_recv_input( server_args.random_seed = broadcast_recv_input(
[server_args.random_seed], [server_args.random_seed],
self.tp_rank, self.tp_rank,
@@ -144,7 +144,7 @@ class ModelTpServer:
)[0] )[0]
set_random_seed(server_args.random_seed) set_random_seed(server_args.random_seed)
# Print info # Print debug info
logger.info( logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, " f"max_total_num_tokens={self.max_total_num_tokens}, "
f"max_prefill_tokens={self.max_prefill_tokens}, " f"max_prefill_tokens={self.max_prefill_tokens}, "
@@ -181,7 +181,7 @@ class ModelTpServer:
self.num_generated_tokens = 0 self.num_generated_tokens = 0
self.last_stats_tic = time.time() self.last_stats_tic = time.time()
# Chunked prefill # Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None self.current_inflight_req = None
self.is_mixed_chunk = ( self.is_mixed_chunk = (
@@ -197,16 +197,6 @@ class ModelTpServer:
"trust_remote_code": server_args.trust_remote_code, "trust_remote_code": server_args.trust_remote_code,
}, },
skip_tokenizer_init=server_args.skip_tokenizer_init, skip_tokenizer_init=server_args.skip_tokenizer_init,
json_schema_mode=False,
)
self.json_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
json_schema_mode=True,
) )
self.jump_forward_cache = JumpForwardCache() self.jump_forward_cache = JumpForwardCache()
@@ -227,11 +217,12 @@ class ModelTpServer:
try: try:
# Recv requests # Recv requests
for recv_req in recv_reqs: for recv_req in recv_reqs:
if isinstance( if isinstance(recv_req, TokenizedGenerateReqInput):
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
self.handle_generate_request(recv_req) self.handle_generate_request(recv_req)
self.do_not_get_new_batch = False self.do_not_get_new_batch = False
elif isinstance(recv_req, TokenizedEmbeddingReqInput):
self.handle_embedding_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, FlushCacheReq): elif isinstance(recv_req, FlushCacheReq):
self.flush_cache() self.flush_cache()
elif isinstance(recv_req, AbortReq): elif isinstance(recv_req, AbortReq):
@@ -331,57 +322,56 @@ class ModelTpServer:
def handle_generate_request( def handle_generate_request(
self, self,
recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], recv_req: TokenizedGenerateReqInput,
): ):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer
req.sampling_params = recv_req.sampling_params req.sampling_params = recv_req.sampling_params
if self.model_runner.is_generation: req.pixel_values = recv_req.pixel_values
req.pixel_values = recv_req.pixel_values if req.pixel_values is not None:
if req.pixel_values is not None: # Use image hash as fake token_ids, which is then used
# Use image hash as fake token_ids, which is then used # for prefix matching
# for prefix matching image_hash = hash(tuple(recv_req.image_hashes))
image_hash = hash(tuple(recv_req.image_hashes)) req.pad_value = [
req.pad_value = [ (image_hash) % self.model_config.vocab_size,
(image_hash) % self.model_config.vocab_size, (image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size, (image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size, (image_hash >> 64) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size, ]
] req.image_sizes = recv_req.image_sizes
req.image_sizes = recv_req.image_sizes (
( req.origin_input_ids,
req.origin_input_ids, req.image_offsets,
req.image_offsets, ) = self.model_runner.model.pad_input_ids(
) = self.model_runner.model.pad_input_ids( req.origin_input_ids_unpadded,
req.origin_input_ids_unpadded, req.pad_value,
req.pad_value, req.pixel_values,
req.pixel_values, req.image_sizes,
req.image_sizes, )
) # Only when pixel values is not None we have modalities
# Only when pixel values is not None we have modalities req.modalities = recv_req.modalites
req.modalities = recv_req.modalites req.return_logprob = recv_req.return_logprob
req.return_logprob = recv_req.return_logprob req.logprob_start_len = recv_req.logprob_start_len
req.logprob_start_len = recv_req.logprob_start_len req.top_logprobs_num = recv_req.top_logprobs_num
req.top_logprobs_num = recv_req.top_logprobs_num req.stream = recv_req.stream
req.stream = recv_req.stream
# Init regex fsm fron json # Init regex FSM
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
):
if req.sampling_params.json_schema is not None: if req.sampling_params.json_schema is not None:
req.regex_fsm, computed_regex_string = self.json_fsm_cache.query( req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
req.sampling_params.json_schema ("json", req.sampling_params.json_schema)
) )
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string
)
# Init regex fsm
elif req.sampling_params.regex is not None: elif req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
if not self.disable_regex_jump_forward: ("regex", req.sampling_params.regex)
req.jump_forward_map = self.jump_forward_cache.query( )
req.sampling_params.regex if not self.disable_regex_jump_forward:
) req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string
)
# Truncate prompts that are too long # Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len: if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -390,16 +380,32 @@ class ModelTpServer:
"the max context length. Truncated!!!" "the max context length. Truncated!!!"
) )
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
)
if self.model_runner.is_generation: self.waiting_queue.append(req)
req.sampling_params.max_new_tokens = min(
( def handle_embedding_request(
req.sampling_params.max_new_tokens self,
if req.sampling_params.max_new_tokens is not None recv_req: TokenizedEmbeddingReqInput,
else 1 << 30 ):
), req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
self.max_req_input_len - 1 - len(req.origin_input_ids), req.tokenizer = self.tokenizer
req.sampling_params = recv_req.sampling_params
# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
logger.warn(
"Request length is longer than the KV cache pool size or "
"the max context length. Truncated!!!"
) )
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
self.waiting_queue.append(req) self.waiting_queue.append(req)
@@ -892,7 +898,6 @@ def run_tp_server(
tp_rank: int, tp_rank: int,
server_args: ServerArgs, server_args: ServerArgs,
nccl_port: int, nccl_port: int,
model_override_args: dict,
): ):
"""Run a tensor parallel model server.""" """Run a tensor parallel model server."""
configure_logger(server_args, prefix=f" TP{tp_rank}") configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -903,7 +908,6 @@ def run_tp_server(
tp_rank, tp_rank,
server_args, server_args,
nccl_port, nccl_port,
model_override_args,
) )
tp_cpu_group = model_server.model_runner.tp_group.cpu_group tp_cpu_group = model_server.model_runner.tp_group.cpu_group
@@ -920,14 +924,13 @@ def launch_tp_servers(
tp_rank_range: List[int], tp_rank_range: List[int],
server_args: ServerArgs, server_args: ServerArgs,
nccl_port: int, nccl_port: int,
model_override_args: dict,
): ):
"""Launch multiple tensor parallel servers.""" """Launch multiple tensor parallel servers."""
procs = [] procs = []
for i in tp_rank_range: for i in tp_rank_range:
proc = multiprocessing.Process( proc = multiprocessing.Process(
target=run_tp_server, target=run_tp_server,
args=(gpu_ids[i], i, server_args, nccl_port, model_override_args), args=(gpu_ids[i], i, server_args, nccl_port),
) )
proc.start() proc.start()
procs.append(proc) procs.append(proc)

View File

@@ -18,6 +18,7 @@ limitations under the License.
import gc import gc
import importlib import importlib
import importlib.resources import importlib.resources
import json
import logging import logging
import pkgutil import pkgutil
from functools import lru_cache from functools import lru_cache

View File

@@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str):
def launch_server( def launch_server(
server_args: ServerArgs, server_args: ServerArgs,
model_override_args: Optional[dict] = None,
pipe_finish_writer: Optional[mp.connection.Connection] = None, pipe_finish_writer: Optional[mp.connection.Connection] = None,
): ):
"""Launch an HTTP server.""" """Launch an HTTP server."""
@@ -317,7 +316,6 @@ def launch_server(
tp_rank_range, tp_rank_range,
server_args, server_args,
ports[3], ports[3],
model_override_args,
) )
try: try:
@@ -328,7 +326,7 @@ def launch_server(
return return
# Launch processes # Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args) tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template: if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
@@ -341,7 +339,7 @@ def launch_server(
proc_controller = mp.Process( proc_controller = mp.Process(
target=start_controller_process, target=start_controller_process,
args=(server_args, port_args, pipe_controller_writer, model_override_args), args=(server_args, port_args, pipe_controller_writer),
) )
proc_controller.start() proc_controller.start()
@@ -501,7 +499,6 @@ class Runtime:
def __init__( def __init__(
self, self,
log_level: str = "error", log_level: str = "error",
model_override_args: Optional[dict] = None,
*args, *args,
**kwargs, **kwargs,
): ):
@@ -525,7 +522,7 @@ class Runtime:
proc = mp.Process( proc = mp.Process(
target=launch_server, target=launch_server,
args=(self.server_args, model_override_args, pipe_writer), args=(self.server_args, pipe_writer),
) )
proc.start() proc.start()
pipe_writer.close() pipe_writer.close()

View File

@@ -76,6 +76,14 @@ class ServerArgs:
dp_size: int = 1 dp_size: int = 1
load_balance_method: str = "round_robin" load_balance_method: str = "round_robin"
# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: str = "{}"
# Optimization/debug options # Optimization/debug options
disable_flashinfer: bool = False disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False disable_flashinfer_sampling: bool = False
@@ -91,14 +99,6 @@ class ServerArgs:
enable_mla: bool = False enable_mla: bool = False
triton_attention_reduce_in_fp32: bool = False triton_attention_reduce_in_fp32: bool = False
# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: Optional[dict] = None
def __post_init__(self): def __post_init__(self):
if self.tokenizer_path is None: if self.tokenizer_path is None:
self.tokenizer_path = self.model_path self.tokenizer_path = self.model_path
@@ -385,6 +385,14 @@ class ServerArgs:
) )
parser.add_argument("--node-rank", type=int, help="The node rank.") parser.add_argument("--node-rank", type=int, help="The node rank.")
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
default=ServerArgs.json_model_override_args,
)
# Optimization/debug options # Optimization/debug options
parser.add_argument( parser.add_argument(
"--disable-flashinfer", "--disable-flashinfer",
@@ -459,22 +467,10 @@ class ServerArgs:
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
) )
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size args.tp_size = args.tensor_parallel_size
args.dp_size = args.data_parallel_size args.dp_size = args.data_parallel_size
args.json_model_override_args = (
json.loads(args.json_model_override_args)
if args.json_model_override_args
else None
)
attrs = [attr.name for attr in dataclasses.fields(cls)] attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs}) return cls(**{attr: getattr(args, attr) for attr in attrs})
@@ -498,7 +494,7 @@ class ServerArgs:
self.disable_flashinfer = False self.disable_flashinfer = False
def prepare_server_args(args: argparse.Namespace) -> ServerArgs: def prepare_server_args(argv: List[str]) -> ServerArgs:
""" """
Prepare the server arguments from the command line arguments. Prepare the server arguments from the command line arguments.
@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser) ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(args) raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args) server_args = ServerArgs.from_cli_args(raw_args)
return server_args return server_args

View File

@@ -0,0 +1,132 @@
"""
Run few-shot GSM-8K evaluation.
Usage:
python3 -m sglang.test.few_shot_gsm8k --num-questions 200
"""
import argparse
import ast
import re
import time
import numpy as np
from sglang.api import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999
def get_one_example(lines, i, include_answer):
    """Render record ``i`` of ``lines`` as a question/answer prompt.

    The prompt is ``"Question: <question>\\nAnswer:"``. When
    ``include_answer`` is truthy, a space and the record's answer text are
    appended; otherwise the prompt ends right after ``"Answer:"``.
    """
    record = lines[i]
    prompt = "Question: " + record["question"] + "\nAnswer:"
    if not include_answer:
        return prompt
    return prompt + " " + record["answer"]
def get_few_shot_examples(lines, k):
    """Build the few-shot prefix from the first ``k`` records of ``lines``.

    Each record is rendered by ``get_one_example`` with its answer included
    and is followed by a blank line. Returns an empty string when ``k`` is 0.
    """
    return "".join(
        get_one_example(lines, idx, True) + "\n\n" for idx in range(k)
    )
def get_answer_value(answer_str):
    """Extract the last integer mentioned in ``answer_str``.

    Commas are removed first so that thousands separators do not split a
    number, then the final run of digits is converted to an ``int``.

    Returns:
        The parsed integer, or ``INVALID`` when the text contains no digits
        or the digit run cannot be converted.
    """
    digit_runs = re.findall(r"\d+", answer_str.replace(",", ""))
    if not digit_runs:
        return INVALID
    try:
        # The regex guarantees a digits-only string, so int() is the right
        # parser. Unlike ast.literal_eval it accepts leading zeros
        # ("007" -> 7) instead of rejecting them with SyntaxError.
        return int(digit_runs[-1])
    except ValueError:
        # int() can still refuse extremely long digit strings (integer
        # string conversion length limit); treat those as unparseable.
        return INVALID
def main(args):
    """Run the few-shot GSM-8K evaluation and print accuracy and speed.

    Uses a ``RuntimeEndpoint`` backend at ``args.host``:``args.port``,
    evaluates the first ``args.num_questions`` records of the GSM-8K test set
    with ``args.num_shots`` solved examples prepended to each prompt, and
    runs the batch with ``args.parallel`` threads. The generated states are
    dumped to ``tmp_output_gsm8k.txt``.
    """
    # Select backend
    endpoint = RuntimeEndpoint(f"{args.host}:{args.port}")
    set_default_backend(endpoint)

    # Read data
    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
    dataset = list(read_jsonl(download_and_cache_file(url)))

    # Construct prompts
    num_questions = args.num_questions
    num_shots = args.num_shots
    few_shot_examples = get_few_shot_examples(dataset, num_shots)

    eval_count = len(dataset[:num_questions])
    questions = [get_one_example(dataset, idx, False) for idx in range(eval_count)]
    labels = [get_answer_value(dataset[idx]["answer"]) for idx in range(eval_count)]
    assert all(label != INVALID for label in labels)
    arguments = [{"question": question} for question in questions]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def few_shot_gsm8k(s, question):
        s += few_shot_examples + question
        s += sgl.gen(
            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
        )

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Run requests
    started_at = time.time()
    states = few_shot_gsm8k.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - started_at

    preds = [get_answer_value(state["answer"]) for state in states]
    pred_array = np.array(preds)

    # Compute accuracy
    acc = np.mean(pred_array == np.array(labels))
    invalid = np.mean(pred_array == INVALID)

    # Compute speed
    num_output_tokens = sum(
        state.get_meta_info("answer")["completion_tokens"] for state in states
    )
    output_throughput = num_output_tokens / latency

    # Print results
    print(f"Accuracy: {acc:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Latency: {latency:.3f} s")
    print(f"Output throughput: {output_throughput:.3f} token/s")

    # Dump results
    dump_state_text("tmp_output_gsm8k.txt", states)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-shots",
        type=int,
        default=5,
        help="Number of solved examples prepended to every prompt.",
    )
    # main() always downloads the dataset from a fixed URL and never reads
    # this option; it is kept so existing command lines keep working.
    parser.add_argument(
        "--data-path",
        type=str,
        default="test.jsonl",
        help="Unused: the dataset is downloaded and cached automatically.",
    )
    parser.add_argument(
        "--num-questions",
        type=int,
        default=200,
        help="Number of test questions to evaluate.",
    )
    parser.add_argument(
        "--parallel",
        type=int,
        default=128,
        help="Number of threads used to run the batch of requests.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="http://127.0.0.1",
        help="Scheme and host of the server endpoint.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Port of the server endpoint.",
    )
    args = parser.parse_args()
    main(args)

View File

@@ -7,7 +7,7 @@ import time
import numpy as np import numpy as np
import sglang as sgl import sglang as sgl
from sglang.utils import fetch_and_cache_jsonl from sglang.utils import download_and_cache_file, read_jsonl
def test_few_shot_qa(): def test_few_shot_qa():
@@ -456,10 +456,6 @@ def test_chat_completion_speculative():
def test_hellaswag_select(): def test_hellaswag_select():
"""Benchmark the accuracy of sgl.select on the HellaSwag dataset.""" """Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
lines = fetch_and_cache_jsonl(url)
# Construct prompts
def get_one_example(lines, i, include_answer): def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " " ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer: if include_answer:
@@ -472,6 +468,12 @@ def test_hellaswag_select():
ret += get_one_example(lines, i, True) + "\n\n" ret += get_one_example(lines, i, True) + "\n\n"
return ret return ret
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
num_questions = 200 num_questions = 200
num_shots = 20 num_shots = 20
few_shot_examples = get_few_shot_examples(lines, num_shots) few_shot_examples = get_few_shot_examples(lines, num_shots)

View File

@@ -12,7 +12,7 @@ import urllib.request
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from io import BytesIO from io import BytesIO
from json import dumps from json import dumps
from typing import Union from typing import Optional, Union
import numpy as np import numpy as np
import requests import requests
@@ -38,13 +38,11 @@ def is_same_type(values: list):
def read_jsonl(filename: str): def read_jsonl(filename: str):
"""Read a JSONL file.""" """Read a JSONL file."""
rets = []
with open(filename) as fin: with open(filename) as fin:
for line in fin: for line in fin:
if line.startswith("#"): if line.startswith("#"):
continue continue
rets.append(json.loads(line)) yield json.loads(line)
return rets
def dump_state_text(filename: str, states: list, mode: str = "w"): def dump_state_text(filename: str, states: list, mode: str = "w"):
@@ -264,38 +262,35 @@ class LazyImport:
return module(*args, **kwargs) return module(*args, **kwargs)
def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"): def download_and_cache_file(url: str, filename: Optional[str] = None):
"""Read and cache a jsonl file from a url.""" """Read and cache a file from a url."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])
# Check if the cache file already exists # Check if the cache file already exists
if os.path.exists(cache_file): if os.path.exists(filename):
print("Loading data from cache...") return filename
with open(cache_file, "r") as f:
data = [json.loads(line) for line in f]
else:
print("Downloading data from URL...")
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status() # Check for request errors
# Total size of the file in bytes print(f"Downloading from {url} to {filename}")
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024 # Download in chunks of 1KB
# Use tqdm to display the progress bar # Stream the response to show the progress bar
with open(cache_file, "wb") as f, tqdm( response = requests.get(url, stream=True)
desc=cache_file, response.raise_for_status() # Check for request errors
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
# Convert the data to a list of dictionaries # Total size of the file in bytes
with open(cache_file, "r") as f: total_size = int(response.headers.get("content-length", 0))
data = [json.loads(line) for line in f] chunk_size = 1024 # Download in chunks of 1KB
return data # Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
return filename

View File

@@ -42,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.63, f"{metrics}" assert metrics["score"] >= 0.62, f"{metrics}"
def test_human_eval(self): def test_human_eval(self):
args = SimpleNamespace( args = SimpleNamespace(
@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.63, f"{metrics}" assert metrics["score"] >= 0.62, f"{metrics}"
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,3 +1,4 @@
import json
import unittest import unittest
from sglang.srt.server_args import prepare_server_args from sglang.srt.server_args import prepare_server_args
@@ -15,7 +16,7 @@ class TestPrepareServerArgs(unittest.TestCase):
) )
self.assertEqual(server_args.model_path, "model_path") self.assertEqual(server_args.model_path, "model_path")
self.assertEqual( self.assertEqual(
server_args.json_model_override_args, json.loads(server_args.json_model_override_args),
{"rope_scaling": {"factor": 2.0, "type": "linear"}}, {"rope_scaling": {"factor": 2.0, "type": "linear"}},
) )