[Minor] Many cleanup (#1357)

This commit is contained in:
Lianmin Zheng
2024-09-09 04:14:11 -07:00
committed by GitHub
parent c9b75917d5
commit e4d68afcf0
24 changed files with 416 additions and 296 deletions

View File

@@ -1,8 +1,3 @@
## Download data
```
bash download_data.sh
```
## Run benchmark
### Benchmark sglang

View File

@@ -10,7 +10,7 @@ import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999
@@ -41,24 +41,28 @@ def get_answer_value(answer_str):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
call_generate = get_call_generate(args)
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
states = [None] * len(labels)
# Select backend
call_generate = get_call_generate(args)
# Run requests
if args.backend != "lmql":
# Use thread pool
@@ -113,11 +117,13 @@ def main(args):
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
@@ -138,7 +144,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=5)
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)

View File

@@ -6,11 +6,12 @@ import time
import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999
@@ -41,15 +42,22 @@ def get_answer_value(answer_str):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
set_default_backend(select_sglang_backend(args))
# Read data
url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
labels.append(get_answer_value(lines[i]["answer"]))
assert all(l != INVALID for l in labels)
@@ -72,15 +80,11 @@ def main(args):
########## SGL Program End ##########
#####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
states = few_shot_gsm8k.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
@@ -96,11 +100,20 @@ def main(args):
# Compute accuracy
acc = np.mean(np.array(preds) == np.array(labels))
invalid = np.mean(np.array(preds) == INVALID)
print(f"Latency: {latency:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Accuracy: {acc:.3f}")
# Write results
# Compute speed
num_output_tokens = sum(
s.get_meta_info("answer")["completion_tokens"] for s in states
)
output_throughput = num_output_tokens / latency
# Print results
print(f"Accuracy: {acc:.3f}")
print(f"Invalid: {invalid:.3f}")
print(f"Latency: {latency:.3f} s")
print(f"Output throughput: {output_throughput:.3f} token/s")
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(args.result_file, "a") as fout:
@@ -121,7 +134,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=5)
parser.add_argument("--num-shots", type=int, default=5)
parser.add_argument("--data-path", type=str, default="test.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)

View File

@@ -1,2 +0,0 @@
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl

View File

@@ -1,8 +1,3 @@
## Download data
```
wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl
```
## Run benchmark
### Benchmark sglang

View File

@@ -8,7 +8,7 @@ import numpy as np
from tqdm import tqdm
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
from sglang.utils import read_jsonl
from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer):
@@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
call_select = get_call_select(args)
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
choices = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
preds = [None] * len(labels)
# Select backend
call_select = get_call_select(args)
# Run requests
if args.backend != "lmql":
# Use thread pool
@@ -65,7 +69,6 @@ def main(args):
total=len(questions),
)
)
else:
# Use asyncio
async def batched_call(batch_size):
@@ -108,7 +111,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20)
parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_other_args_and_parse(parser)

View File

@@ -4,11 +4,12 @@ import time
import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import read_jsonl
from sglang.utils import download_and_cache_file, read_jsonl
def get_one_example(lines, i, include_answer):
@@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k):
def main(args):
lines = read_jsonl(args.data_path)
# Select backend
set_default_backend(select_sglang_backend(args))
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
k = args.num_shot
few_shot_examples = get_few_shot_examples(lines, k)
num_questions = args.num_questions
num_shots = args.num_shots
few_shot_examples = get_few_shot_examples(lines, num_shots)
questions = []
choices = []
labels = []
for i in range(len(lines[: args.num_questions])):
for i in range(len(lines[:num_questions])):
questions.append(get_one_example(lines, i, False))
choices.append(lines[i]["endings"])
labels.append(lines[i]["label"])
@@ -56,15 +64,11 @@ def main(args):
########## SGL Program End ##########
#####################################
# Select backend
backend = select_sglang_backend(args)
# Run requests
tic = time.time()
rets = few_shot_hellaswag.run_batch(
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
@@ -95,7 +99,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num-shot", type=int, default=20)
parser.add_argument("--num-shots", type=int, default=20)
parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
parser.add_argument("--num-questions", type=int, default=200)
args = add_common_sglang_args_and_parse(parser)

View File

@@ -7,6 +7,7 @@ python3 srt_example_llava_v.py
import argparse
import csv
import json
import os
import time
@@ -223,7 +224,7 @@ if __name__ == "__main__":
tokenizer_path=tokenizer_path,
port=cur_port,
additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
model_override_args=model_override_args,
json_model_override_args=json.dumps(model_override_args),
tp_size=1,
)
sgl.set_default_backend(runtime)

View File

@@ -298,34 +298,41 @@ class BenchmarkMetrics:
median_e2e_latency_ms: float
default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
def download_sharegpt_dataset(path):
url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
def download_and_cache_file(url: str, filename: Optional[str] = None):
"""Read and cache a file from a url."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])
print(f"Downloading dataset from {url}")
try:
response = requests.get(url, stream=True)
response.raise_for_status()
# Check if the cache file already exists
if os.path.exists(filename):
return filename
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
print(f"Downloading from {url} to {filename}")
with open(path, "wb") as f, tqdm(
desc="Downloading",
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as progress_bar:
for data in response.iter_content(block_size):
size = f.write(data)
progress_bar.update(size)
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status() # Check for request errors
print(f"Dataset downloaded and saved to {path}")
except requests.RequestException as e:
raise Exception(f"Failed to download dataset: {e}")
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024 # Download in chunks of 1KB
# Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
return filename
def sample_sharegpt_requests(
@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
raise ValueError("output_len too small")
# Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
download_sharegpt_dataset(default_sharegpt_path)
dataset_path = default_sharegpt_path
else:
dataset_path = (
dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
)
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:
@@ -412,15 +414,8 @@ def sample_random_requests(
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
# Download sharegpt if necessary
if not os.path.isfile(dataset_path) and not os.path.isfile(
default_sharegpt_path
):
download_sharegpt_dataset(default_sharegpt_path)
dataset_path = default_sharegpt_path
else:
dataset_path = (
dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
)
if not os.path.isfile(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
with open(dataset_path) as f:

View File

@@ -9,10 +9,9 @@ from sglang.srt.utils import kill_child_process
if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])
model_override_args = server_args.json_model_override_args
try:
launch_server(server_args, model_override_args=model_override_args)
launch_server(server_args)
except Exception as e:
raise e
finally:

View File

@@ -1,5 +1,6 @@
"""Launch the inference server for Llava-video model."""
import json
import sys
from sglang.srt.server import launch_server, prepare_server_args
@@ -19,5 +20,6 @@ if __name__ == "__main__":
model_override_args["model_max_length"] = 4096 * 2
if "34b" in server_args.model_path.lower():
model_override_args["image_token_index"] = 64002
server_args.json_model_override_args = json.dumps(model_override_args)
launch_server(server_args, model_override_args, None)
launch_server(server_args)

View File

@@ -16,6 +16,7 @@ limitations under the License.
"""Cache for the compressed finite state machine."""
from outlines.fsm.json_schema import build_regex_from_schema
from transformers import AutoTokenizer
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_tool_cache import BaseToolCache
@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
tokenizer_args_dict,
enable=True,
skip_tokenizer_init=False,
json_schema_mode=False,
):
super().__init__(enable=enable)
self.json_schema_mode = json_schema_mode
if (
skip_tokenizer_init
or tokenizer_path.endswith(".json")
@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
# Do not support TiktokenTokenizer or SentencePieceTokenizer
return
from importlib.metadata import version
tokenizer_args_dict.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
try:
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
except AttributeError:
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
origin_pad_token_id = tokenizer.pad_token_id
if version("outlines") >= "0.0.35":
from transformers import AutoTokenizer
def fset(self, value):
self._value = value
tokenizer_args_dict.setdefault("padding_side", "left")
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, **tokenizer_args_dict
type(tokenizer).pad_token_id = property(
fget=type(tokenizer).pad_token_id.fget, fset=fset
)
try:
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
except AttributeError:
# FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
origin_pad_token_id = tokenizer.pad_token_id
def fset(self, value):
self._value = value
type(tokenizer).pad_token_id = property(
fget=type(tokenizer).pad_token_id.fget, fset=fset
)
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token = (
self.outlines_tokenizer.tokenizer.pad_token
)
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
else:
self.outlines_tokenizer = TransformerTokenizer(
tokenizer_path, **tokenizer_args_dict
self.outlines_tokenizer = TransformerTokenizer(tokenizer)
self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token_id = origin_pad_token_id
self.outlines_tokenizer.pad_token = (
self.outlines_tokenizer.tokenizer.pad_token
)
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
def init_value(self, value):
if self.json_schema_mode:
regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*")
return RegexGuide(regex, self.outlines_tokenizer), regex
def init_value(self, key):
key_type, key_string = key
if key_type == "json":
regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
elif key_type == "regex":
regex = key_string
else:
return RegexGuide(value, self.outlines_tokenizer)
raise ValueError(f"Invalid key_type: {key_type}")
return RegexGuide(regex, self.outlines_tokenizer), regex

View File

@@ -71,12 +71,10 @@ class ControllerMulti:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_override_args,
):
# Parse args
self.server_args = server_args
self.port_args = port_args
self.model_override_args = model_override_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)
@@ -114,7 +112,6 @@ class ControllerMulti:
self.server_args,
self.port_args,
pipe_controller_writer,
self.model_override_args,
True,
gpu_ids,
dp_worker_id,
@@ -189,14 +186,13 @@ def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
model_override_args: dict,
):
"""Start a controller process."""
configure_logger(server_args)
try:
controller = ControllerMulti(server_args, port_args, model_override_args)
controller = ControllerMulti(server_args, port_args)
except Exception:
pipe_writer.send(get_exception_traceback())
raise

View File

@@ -40,7 +40,6 @@ class ControllerSingle:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_override_args: dict,
gpu_ids: List[int],
is_data_parallel_worker: bool,
dp_worker_id: int,
@@ -76,7 +75,6 @@ class ControllerSingle:
tp_rank_range,
server_args,
port_args.nccl_ports[dp_worker_id],
model_override_args,
)
# Launch tp rank 0
@@ -85,7 +83,6 @@ class ControllerSingle:
0,
server_args,
port_args.nccl_ports[dp_worker_id],
model_override_args,
)
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
@@ -126,7 +123,6 @@ def start_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer: multiprocessing.connection.Connection,
model_override_args: dict,
is_data_parallel_worker: bool = False,
gpu_ids: List[int] = None,
dp_worker_id: int = None,
@@ -149,7 +145,6 @@ def start_controller_process(
controller = ControllerSingle(
server_args,
port_args,
model_override_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,

View File

@@ -18,6 +18,7 @@ limitations under the License.
import asyncio
import concurrent.futures
import dataclasses
import json
import logging
import multiprocessing as mp
import os
@@ -77,7 +78,6 @@ class TokenizerManager:
self,
server_args: ServerArgs,
port_args: PortArgs,
model_override_args: dict = None,
):
self.server_args = server_args
@@ -95,7 +95,7 @@ class TokenizerManager:
self.hf_config = get_config(
self.model_path,
trust_remote_code=server_args.trust_remote_code,
model_override_args=model_override_args,
model_override_args=json.loads(server_args.json_model_override_args),
)
self.is_generation = is_generation_model(
self.hf_config.architectures, self.server_args.is_embedding

View File

@@ -15,13 +15,14 @@ limitations under the License.
"""A tensor parallel worker."""
import json
import logging
import multiprocessing
import os
import pickle
import time
import warnings
from typing import Any, List, Optional, Union
from typing import Any, List, Optional
import torch
import torch.distributed
@@ -66,6 +67,7 @@ from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
@@ -76,11 +78,10 @@ class ModelTpServer:
tp_rank: int,
server_args: ServerArgs,
nccl_port: int,
model_override_args: dict,
):
suppress_other_loggers()
# Copy arguments
# Parse arguments
self.gpu_id = gpu_id
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
@@ -93,9 +94,8 @@ class ModelTpServer:
server_args.model_path,
server_args.trust_remote_code,
context_length=server_args.context_length,
model_override_args=model_override_args,
model_override_args=json.loads(server_args.json_model_override_args),
)
self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
@@ -136,7 +136,7 @@ class ModelTpServer:
self.max_total_num_tokens - 1,
)
# Sync random seed
# Sync random seed across TP workers
server_args.random_seed = broadcast_recv_input(
[server_args.random_seed],
self.tp_rank,
@@ -144,7 +144,7 @@ class ModelTpServer:
)[0]
set_random_seed(server_args.random_seed)
# Print info
# Print debug info
logger.info(
f"max_total_num_tokens={self.max_total_num_tokens}, "
f"max_prefill_tokens={self.max_prefill_tokens}, "
@@ -181,7 +181,7 @@ class ModelTpServer:
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
# Chunked prefill
# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None
self.is_mixed_chunk = (
@@ -197,16 +197,6 @@ class ModelTpServer:
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
json_schema_mode=False,
)
self.json_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
json_schema_mode=True,
)
self.jump_forward_cache = JumpForwardCache()
@@ -227,11 +217,12 @@ class ModelTpServer:
try:
# Recv requests
for recv_req in recv_reqs:
if isinstance(
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, TokenizedEmbeddingReqInput):
self.handle_embedding_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
@@ -331,57 +322,56 @@ class ModelTpServer:
def handle_generate_request(
self,
recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
recv_req: TokenizedGenerateReqInput,
):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.tokenizer = self.tokenizer
req.sampling_params = recv_req.sampling_params
if self.model_runner.is_generation:
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
# Use image hash as fake token_ids, which is then used
# for prefix matching
image_hash = hash(tuple(recv_req.image_hashes))
req.pad_value = [
(image_hash) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size,
]
req.image_sizes = recv_req.image_sizes
(
req.origin_input_ids,
req.image_offsets,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values,
req.image_sizes,
)
# Only when pixel values is not None we have modalities
req.modalities = recv_req.modalites
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None:
# Use image hash as fake token_ids, which is then used
# for prefix matching
image_hash = hash(tuple(recv_req.image_hashes))
req.pad_value = [
(image_hash) % self.model_config.vocab_size,
(image_hash >> 16) % self.model_config.vocab_size,
(image_hash >> 32) % self.model_config.vocab_size,
(image_hash >> 64) % self.model_config.vocab_size,
]
req.image_sizes = recv_req.image_sizes
(
req.origin_input_ids,
req.image_offsets,
) = self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values,
req.image_sizes,
)
# Only when pixel values is not None we have modalities
req.modalities = recv_req.modalites
req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
# Init regex fsm fron json
# Init regex FSM
if (
req.sampling_params.json_schema is not None
or req.sampling_params.regex is not None
):
if req.sampling_params.json_schema is not None:
req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
req.sampling_params.json_schema
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
("json", req.sampling_params.json_schema)
)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string
)
# Init regex fsm
elif req.sampling_params.regex is not None:
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
req.sampling_params.regex
)
req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
("regex", req.sampling_params.regex)
)
if not self.disable_regex_jump_forward:
req.jump_forward_map = self.jump_forward_cache.query(
computed_regex_string
)
# Truncate prompts that are too long
if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -390,16 +380,32 @@ class ModelTpServer:
"the max context length. Truncated!!!"
)
req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
)
if self.model_runner.is_generation:
req.sampling_params.max_new_tokens = min(
(
req.sampling_params.max_new_tokens
if req.sampling_params.max_new_tokens is not None
else 1 << 30
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
self.waiting_queue.append(req)
def handle_embedding_request(
    self,
    recv_req: TokenizedEmbeddingReqInput,
):
    """Turn a tokenized embedding request into a Req and enqueue it.

    A Req is built from the request's rid, input text and input ids; the
    worker's tokenizer and the request's sampling params are attached.
    If the input has at least self.max_req_input_len entries it is cut
    down to that length (with a warning) before the Req is appended to
    self.waiting_queue.

    Args:
        recv_req: The tokenized embedding request received by this worker.
    """
    req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
    req.tokenizer = self.tokenizer
    req.sampling_params = recv_req.sampling_params

    # Truncate prompts that are too long
    if len(req.origin_input_ids) >= self.max_req_input_len:
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning(
            "Request length is longer than the KV cache pool size or "
            "the max context length. Truncated!!!"
        )
        req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

    self.waiting_queue.append(req)
@@ -892,7 +898,6 @@ def run_tp_server(
tp_rank: int,
server_args: ServerArgs,
nccl_port: int,
model_override_args: dict,
):
"""Run a tensor parallel model server."""
configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -903,7 +908,6 @@ def run_tp_server(
tp_rank,
server_args,
nccl_port,
model_override_args,
)
tp_cpu_group = model_server.model_runner.tp_group.cpu_group
@@ -920,14 +924,13 @@ def launch_tp_servers(
tp_rank_range: List[int],
server_args: ServerArgs,
nccl_port: int,
model_override_args: dict,
):
"""Launch multiple tensor parallel servers."""
procs = []
for i in tp_rank_range:
proc = multiprocessing.Process(
target=run_tp_server,
args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
args=(gpu_ids[i], i, server_args, nccl_port),
)
proc.start()
procs.append(proc)

View File

@@ -18,6 +18,7 @@ limitations under the License.
import gc
import importlib
import importlib.resources
import json
import logging
import pkgutil
from functools import lru_cache

View File

@@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str):
def launch_server(
server_args: ServerArgs,
model_override_args: Optional[dict] = None,
pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
"""Launch an HTTP server."""
@@ -317,7 +316,6 @@ def launch_server(
tp_rank_range,
server_args,
ports[3],
model_override_args,
)
try:
@@ -328,7 +326,7 @@ def launch_server(
return
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
tokenizer_manager = TokenizerManager(server_args, port_args)
if server_args.chat_template:
load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
@@ -341,7 +339,7 @@ def launch_server(
proc_controller = mp.Process(
target=start_controller_process,
args=(server_args, port_args, pipe_controller_writer, model_override_args),
args=(server_args, port_args, pipe_controller_writer),
)
proc_controller.start()
@@ -501,7 +499,6 @@ class Runtime:
def __init__(
self,
log_level: str = "error",
model_override_args: Optional[dict] = None,
*args,
**kwargs,
):
@@ -525,7 +522,7 @@ class Runtime:
proc = mp.Process(
target=launch_server,
args=(self.server_args, model_override_args, pipe_writer),
args=(self.server_args, pipe_writer),
)
proc.start()
pipe_writer.close()

View File

@@ -76,6 +76,14 @@ class ServerArgs:
dp_size: int = 1
load_balance_method: str = "round_robin"
# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: str = "{}"
# Optimization/debug options
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
@@ -91,14 +99,6 @@ class ServerArgs:
enable_mla: bool = False
triton_attention_reduce_in_fp32: bool = False
# Distributed args
nccl_init_addr: Optional[str] = None
nnodes: int = 1
node_rank: Optional[int] = None
# Model override args in JSON
json_model_override_args: Optional[dict] = None
def __post_init__(self):
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
@@ -385,6 +385,14 @@ class ServerArgs:
)
parser.add_argument("--node-rank", type=int, help="The node rank.")
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
default=ServerArgs.json_model_override_args,
)
# Optimization/debug options
parser.add_argument(
"--disable-flashinfer",
@@ -459,22 +467,10 @@ class ServerArgs:
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
# Model override args
parser.add_argument(
"--json-model-override-args",
type=str,
help="A dictionary in JSON string format used to override default model configurations.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
args.dp_size = args.data_parallel_size
args.json_model_override_args = (
json.loads(args.json_model_override_args)
if args.json_model_override_args
else None
)
attrs = [attr.name for attr in dataclasses.fields(cls)]
return cls(**{attr: getattr(args, attr) for attr in attrs})
@@ -498,7 +494,7 @@ class ServerArgs:
self.disable_flashinfer = False
def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
Prepare the server arguments from the command line arguments.
@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
"""
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw_args = parser.parse_args(args)
raw_args = parser.parse_args(argv)
server_args = ServerArgs.from_cli_args(raw_args)
return server_args

View File

@@ -0,0 +1,132 @@
"""
Run few-shot GSM-8K evaluation.
Usage:
python3 -m sglang.test.few_shot_gsm8k --num-questions 200
"""
import argparse
import ast
import re
import time
import numpy as np
from sglang.api import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
INVALID = -9999999
def get_one_example(lines, i, include_answer):
    """Render record i of the dataset as a question/answer prompt.

    The prompt is the literal prefix "Question: ", the record's question,
    a newline, and "Answer:". When include_answer is true, a single space
    and the record's reference answer are appended, giving a solved
    example; otherwise the prompt is left open for the model to complete.
    """
    record = lines[i]
    parts = ["Question: ", record["question"], "\nAnswer:"]
    if include_answer:
        parts.extend((" ", record["answer"]))
    return "".join(parts)
def get_few_shot_examples(lines, k):
    """Build the few-shot prefix from the first k records.

    Each of the records 0..k-1 is rendered by get_one_example with its
    answer included and is followed by a blank line. A k of zero or less
    yields an empty string.
    """
    return "".join(
        get_one_example(lines, idx, True) + "\n\n" for idx in range(k)
    )
def get_answer_value(answer_str):
    """Extract the last integer that appears in ``answer_str``.

    Commas are removed first, so "1,234" is read as 1234. Only runs of
    decimal digits are matched: a sign or decimal point is not part of a
    match, so "3.50" yields 50. Returns ``INVALID`` when the string contains
    no digits or when the last digit run is rejected by ``ast.literal_eval``
    with a ``SyntaxError``.
    """
    digit_runs = re.findall(r"\d+", answer_str.replace(",", ""))
    if not digit_runs:
        return INVALID
    try:
        value = ast.literal_eval(digit_runs[-1])
    except SyntaxError:
        return INVALID
    return value
def main(args):
    """Run the few-shot GSM-8K evaluation and print the results.

    Reads ``args.host``, ``args.port``, ``args.num_questions``,
    ``args.num_shots`` and ``args.parallel``. The test split is fetched with
    ``download_and_cache_file``, each question is prefixed with the few-shot
    examples, and the batch is generated through the default SGLang backend.
    Accuracy, invalid rate, latency and output throughput are printed, and
    the states are passed to ``dump_state_text`` with the file name
    "tmp_output_gsm8k.txt".
    """
    # Select backend
    endpoint = RuntimeEndpoint(f"{args.host}:{args.port}")
    set_default_backend(endpoint)

    # Read data
    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
    lines = list(read_jsonl(download_and_cache_file(url)))

    # Construct prompts
    few_shot_examples = get_few_shot_examples(lines, args.num_shots)
    question_indices = range(len(lines[: args.num_questions]))
    questions = [get_one_example(lines, idx, False) for idx in question_indices]
    labels = [get_answer_value(lines[idx]["answer"]) for idx in question_indices]
    # Every reference answer must contain a parsable integer.
    assert all(label != INVALID for label in labels)
    arguments = [{"question": question} for question in questions]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def few_shot_gsm8k(s, question):
        s += few_shot_examples + question
        s += sgl.gen(
            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
        )

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Run requests
    started_at = time.time()
    states = few_shot_gsm8k.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - started_at

    preds = [get_answer_value(state["answer"]) for state in states]

    # Compute accuracy
    pred_array = np.array(preds)
    acc = np.mean(pred_array == np.array(labels))
    invalid = np.mean(pred_array == INVALID)

    # Compute speed
    num_output_tokens = sum(
        state.get_meta_info("answer")["completion_tokens"] for state in states
    )
    output_throughput = num_output_tokens / latency

    # Print results
    print(f"Accuracy: {acc:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Latency: {latency:.3f} s")
    print(f"Output throughput: {output_throughput:.3f} token/s")

    # Dump results
    dump_state_text("tmp_output_gsm8k.txt", states)
if __name__ == "__main__":
    # Command-line entry point: parse the benchmark options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-shots",
        type=int,
        default=5,
        help="Number of solved examples placed before each question.",
    )
    # NOTE: main() always downloads the GSM-8K test split and never reads this
    # option; it is kept only so that existing command lines still parse.
    parser.add_argument(
        "--data-path",
        type=str,
        default="test.jsonl",
        help="Ignored: the dataset is downloaded and cached automatically.",
    )
    parser.add_argument(
        "--num-questions",
        type=int,
        default=200,
        help="Number of test questions to evaluate.",
    )
    parser.add_argument(
        "--parallel",
        type=int,
        default=128,
        help="Number of threads used to run the batch.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="http://127.0.0.1",
        help="Base URL (scheme and host) of the runtime endpoint.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Port of the runtime endpoint.",
    )
    args = parser.parse_args()
    main(args)

View File

@@ -7,7 +7,7 @@ import time
import numpy as np
import sglang as sgl
from sglang.utils import fetch_and_cache_jsonl
from sglang.utils import download_and_cache_file, read_jsonl
def test_few_shot_qa():
@@ -456,10 +456,6 @@ def test_chat_completion_speculative():
def test_hellaswag_select():
"""Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
lines = fetch_and_cache_jsonl(url)
# Construct prompts
def get_one_example(lines, i, include_answer):
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
if include_answer:
@@ -472,6 +468,12 @@ def test_hellaswag_select():
ret += get_one_example(lines, i, True) + "\n\n"
return ret
# Read data
url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
filename = download_and_cache_file(url)
lines = list(read_jsonl(filename))
# Construct prompts
num_questions = 200
num_shots = 20
few_shot_examples = get_few_shot_examples(lines, num_shots)

View File

@@ -12,7 +12,7 @@ import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
from typing import Union
from typing import Optional, Union
import numpy as np
import requests
@@ -38,13 +38,11 @@ def is_same_type(values: list):
def read_jsonl(filename: str):
"""Read a JSONL file."""
rets = []
with open(filename) as fin:
for line in fin:
if line.startswith("#"):
continue
rets.append(json.loads(line))
return rets
yield json.loads(line)
def dump_state_text(filename: str, states: list, mode: str = "w"):
@@ -264,38 +262,35 @@ class LazyImport:
return module(*args, **kwargs)
def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
"""Read and cache a jsonl file from a url."""
def download_and_cache_file(url: str, filename: Optional[str] = None):
"""Read and cache a file from a url."""
if filename is None:
filename = os.path.join("/tmp", url.split("/")[-1])
# Check if the cache file already exists
if os.path.exists(cache_file):
print("Loading data from cache...")
with open(cache_file, "r") as f:
data = [json.loads(line) for line in f]
else:
print("Downloading data from URL...")
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status() # Check for request errors
if os.path.exists(filename):
return filename
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024 # Download in chunks of 1KB
print(f"Downloading from {url} to {filename}")
# Use tqdm to display the progress bar
with open(cache_file, "wb") as f, tqdm(
desc=cache_file,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
# Stream the response to show the progress bar
response = requests.get(url, stream=True)
response.raise_for_status() # Check for request errors
# Convert the data to a list of dictionaries
with open(cache_file, "r") as f:
data = [json.loads(line) for line in f]
# Total size of the file in bytes
total_size = int(response.headers.get("content-length", 0))
chunk_size = 1024 # Download in chunks of 1KB
return data
# Use tqdm to display the progress bar
with open(filename, "wb") as f, tqdm(
desc=filename,
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as bar:
for chunk in response.iter_content(chunk_size=chunk_size):
f.write(chunk)
bar.update(len(chunk))
return filename

View File

@@ -42,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.63, f"{metrics}"
assert metrics["score"] >= 0.62, f"{metrics}"
def test_human_eval(self):
args = SimpleNamespace(
@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.63, f"{metrics}"
assert metrics["score"] >= 0.62, f"{metrics}"
if __name__ == "__main__":

View File

@@ -1,3 +1,4 @@
import json
import unittest
from sglang.srt.server_args import prepare_server_args
@@ -15,7 +16,7 @@ class TestPrepareServerArgs(unittest.TestCase):
)
self.assertEqual(server_args.model_path, "model_path")
self.assertEqual(
server_args.json_model_override_args,
json.loads(server_args.json_model_override_args),
{"rope_scaling": {"factor": 2.0, "type": "linear"}},
)