From e4d68afcf00869a5467f101d176fecc3cd97b7b8 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 9 Sep 2024 04:14:11 -0700
Subject: [PATCH] [Minor] Many cleanup (#1357)

---
 benchmark/gsm8k/README.md | 5 -
 benchmark/gsm8k/bench_other.py | 32 ++--
 benchmark/gsm8k/bench_sglang.py | 41 +++--
 benchmark/gsm8k/download_data.sh | 2 -
 benchmark/hellaswag/README.md | 5 -
 benchmark/hellaswag/bench_other.py | 23 +--
 benchmark/hellaswag/bench_sglang.py | 24 +--
 .../usage/llava_video/srt_example_llava_v.py | 3 +-
 python/sglang/bench_serving.py | 71 ++++----
 python/sglang/launch_server.py | 3 +-
 python/sglang/launch_server_llavavid.py | 4 +-
 python/sglang/srt/constrained/fsm_cache.py | 67 ++++----
 .../sglang/srt/managers/controller_multi.py | 6 +-
 .../sglang/srt/managers/controller_single.py | 5 -
 .../sglang/srt/managers/tokenizer_manager.py | 4 +-
 python/sglang/srt/managers/tp_worker.py | 155 +++++++++---------
 .../sglang/srt/model_executor/model_runner.py | 1 +
 python/sglang/srt/server.py | 9 +-
 python/sglang/srt/server_args.py | 40 ++---
 python/sglang/test/few_shot_gsm8k.py | 132 +++++++++++++++
 python/sglang/test/test_programs.py | 12 +-
 python/sglang/utils.py | 61 ++++---
 test/srt/test_moe_eval_accuracy_large.py | 4 +-
 test/srt/test_server_args.py | 3 +-
 24 files changed, 416 insertions(+), 296 deletions(-)
 delete mode 100755 benchmark/gsm8k/download_data.sh
 create mode 100644 python/sglang/test/few_shot_gsm8k.py

diff --git a/benchmark/gsm8k/README.md b/benchmark/gsm8k/README.md
index a7dc04d9a..c110f533c 100644
--- a/benchmark/gsm8k/README.md
+++ b/benchmark/gsm8k/README.md
@@ -1,8 +1,3 @@
-## Download data
-```
-bash download_data.sh
-```
-
 ## Run benchmark
 
 ### Benchmark sglang
diff --git a/benchmark/gsm8k/bench_other.py b/benchmark/gsm8k/bench_other.py
index 2a938d6bb..a8bbcfb5c 100644
--- a/benchmark/gsm8k/bench_other.py
+++ b/benchmark/gsm8k/bench_other.py
@@ -10,7 +10,7 @@ import numpy as np
 from tqdm import tqdm
 
 from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
-from sglang.utils import dump_state_text, read_jsonl
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
 INVALID = -9999999
 
@@ -41,24 +41,28 @@ def get_answer_value(answer_str):
 
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    call_generate = get_call_generate(args)
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
 
     states = [None] * len(labels)
 
-    # Select backend
-    call_generate = get_call_generate(args)
-
     # Run requests
     if args.backend != "lmql":
         # Use thread pool
@@ -113,11 +117,13 @@ def main(args):
     # Compute accuracy
     acc = np.mean(np.array(preds) == np.array(labels))
     invalid = np.mean(np.array(preds) == INVALID)
-    print(f"Latency: {latency:.3f}")
-    print(f"Invalid: {invalid:.3f}")
-    print(f"Accuracy: {acc:.3f}")
 
-    # Write results
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+
+    # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
 
     with open(args.result_file, "a") as fout:
@@ -138,7 +144,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=5)
+    parser.add_argument("--num-shots", type=int, default=5)
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_other_args_and_parse(parser)
diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py
index d32790fe0..9fe9b79ba 100644
--- a/benchmark/gsm8k/bench_sglang.py
+++ b/benchmark/gsm8k/bench_sglang.py
@@ -6,11 +6,12 @@ import time
 
 import numpy as np
 
+from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
-from sglang.utils import dump_state_text, read_jsonl
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
 INVALID = -9999999
 
@@ -41,15 +42,22 @@ def get_answer_value(answer_str):
 
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    set_default_backend(select_sglang_backend(args))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
@@ -72,15 +80,11 @@ def main(args):
     ########## SGL Program End ##########
     #####################################
 
-    # Select backend
-    backend = select_sglang_backend(args)
-
     # Run requests
     tic = time.time()
     states = few_shot_gsm8k.run_batch(
         arguments,
         temperature=0,
-        backend=backend,
         num_threads=args.parallel,
         progress_bar=True,
     )
@@ -96,11 +100,20 @@ def main(args):
     # Compute accuracy
     acc = np.mean(np.array(preds) == np.array(labels))
     invalid = np.mean(np.array(preds) == INVALID)
-    print(f"Latency: {latency:.3f}")
-    print(f"Invalid: {invalid:.3f}")
-    print(f"Accuracy: {acc:.3f}")
 
-    # Write results
+    # Compute speed
+    num_output_tokens = sum(
+        s.get_meta_info("answer")["completion_tokens"] for s in states
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")
+
+    # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
 
     with open(args.result_file, "a") as fout:
@@ -121,7 +134,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=5)
+    parser.add_argument("--num-shots", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_sglang_args_and_parse(parser)
diff --git a/benchmark/gsm8k/download_data.sh b/benchmark/gsm8k/download_data.sh
deleted file mode 100755
index a9aa7756d..000000000
--- a/benchmark/gsm8k/download_data.sh
+++ /dev/null
@@ -1,2 +0,0 @@
-wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
-wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
\ No newline at end of file
diff --git a/benchmark/hellaswag/README.md b/benchmark/hellaswag/README.md
index b3e7abc30..cb7e65366 100644
--- a/benchmark/hellaswag/README.md
+++ b/benchmark/hellaswag/README.md
@@ -1,8 +1,3 @@
-## Download data
-```
-wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl
-```
-
 ## Run benchmark
 
 ### Benchmark sglang
diff --git a/benchmark/hellaswag/bench_other.py b/benchmark/hellaswag/bench_other.py
index 5b9ba797b..04be4569a 100644
--- a/benchmark/hellaswag/bench_other.py
+++ b/benchmark/hellaswag/bench_other.py
@@ -8,7 +8,7 @@ import numpy as np
 from tqdm import tqdm
 
 from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
-from sglang.utils import read_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 def get_one_example(lines, i, include_answer):
@@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k):
 
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    call_select = get_call_select(args)
+
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     choices = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         choices.append(lines[i]["endings"])
         labels.append(lines[i]["label"])
 
     preds = [None] * len(labels)
 
-    # Select backend
-    call_select = get_call_select(args)
-
     # Run requests
     if args.backend != "lmql":
         # Use thread pool
@@ -65,7 +69,6 @@ def main(args):
                     total=len(questions),
                 )
             )
-
     else:
         # Use asyncio
         async def batched_call(batch_size):
@@ -108,7 +111,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=20)
+    parser.add_argument("--num-shots", type=int, default=20)
     parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_other_args_and_parse(parser)
diff --git a/benchmark/hellaswag/bench_sglang.py b/benchmark/hellaswag/bench_sglang.py
index 2ccf1aaee..f09d7256d 100644
--- a/benchmark/hellaswag/bench_sglang.py
+++ b/benchmark/hellaswag/bench_sglang.py
@@ -4,11 +4,12 @@ import time
 
 import numpy as np
 
+from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
-from sglang.utils import read_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 def get_one_example(lines, i, include_answer):
@@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k):
 
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    set_default_backend(select_sglang_backend(args))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     choices = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         choices.append(lines[i]["endings"])
         labels.append(lines[i]["label"])
@@ -56,15 +64,11 @@ def main(args):
     ########## SGL Program End ##########
     #####################################
 
-    # Select backend
-    backend = select_sglang_backend(args)
-
     # Run requests
     tic = time.time()
     rets = few_shot_hellaswag.run_batch(
         arguments,
         temperature=0,
-        backend=backend,
         num_threads=args.parallel,
         progress_bar=True,
     )
@@ -95,7 +99,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=20)
+    parser.add_argument("--num-shots", type=int, default=20)
     parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_sglang_args_and_parse(parser)
diff --git a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py
index 02bab342a..c3b8da7d6 100644
--- a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py
+++ b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py
@@ -7,6 +7,7 @@ python3 srt_example_llava_v.py
 
 import argparse
 import csv
+import json
 import os
 import time
 
@@ -223,7 +224,7 @@ if __name__ == "__main__":
             tokenizer_path=tokenizer_path,
             port=cur_port,
             additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
-            model_override_args=model_override_args,
+            json_model_override_args=json.dumps(model_override_args),
             tp_size=1,
         )
         sgl.set_default_backend(runtime)
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 69d175d84..d51aee4ec 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -298,34 +298,41 @@ class BenchmarkMetrics:
     median_e2e_latency_ms: float
 
 
-default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
 
 
-def download_sharegpt_dataset(path):
-    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
 
-    print(f"Downloading dataset from {url}")
-    try:
-        response = requests.get(url, stream=True)
-        response.raise_for_status()
+    # Check if the cache file already exists
+    if os.path.exists(filename):
+        return filename
 
-        total_size = int(response.headers.get("content-length", 0))
-        block_size = 8192
+    print(f"Downloading from {url} to {filename}")
 
-        with open(path, "wb") as f, tqdm(
-            desc="Downloading",
-            total=total_size,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as progress_bar:
-            for data in response.iter_content(block_size):
-                size = f.write(data)
-                progress_bar.update(size)
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
 
-        print(f"Dataset downloaded and saved to {path}")
-    except requests.RequestException as e:
-        raise Exception(f"Failed to download dataset: {e}")
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
 
 
 def sample_sharegpt_requests(
@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
-        download_sharegpt_dataset(default_sharegpt_path)
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
     with open(dataset_path) as f:
@@ -412,15 +414,8 @@ def sample_random_requests(
         # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
 
         # Download sharegpt if necessary
-        if not os.path.isfile(dataset_path) and not os.path.isfile(
-            default_sharegpt_path
-        ):
-            download_sharegpt_dataset(default_sharegpt_path)
-            dataset_path = default_sharegpt_path
-        else:
-            dataset_path = (
-                dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-            )
+        if not os.path.isfile(dataset_path):
+            dataset_path = download_and_cache_file(SHAREGPT_URL)
 
         # Load the dataset.
         with open(dataset_path) as f:
diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 06aa140d9..ce4cb07c2 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -9,10 +9,9 @@ from sglang.srt.utils import kill_child_process
 
 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])
-    model_override_args = server_args.json_model_override_args
 
     try:
-        launch_server(server_args, model_override_args=model_override_args)
+        launch_server(server_args)
     except Exception as e:
         raise e
     finally:
diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py
index 6b8d151ee..6816dcc11 100644
--- a/python/sglang/launch_server_llavavid.py
+++ b/python/sglang/launch_server_llavavid.py
@@ -1,5 +1,6 @@
 """Launch the inference server for Llava-video model."""
 
+import json
 import sys
 
 from sglang.srt.server import launch_server, prepare_server_args
@@ -19,5 +20,6 @@ if __name__ == "__main__":
     model_override_args["model_max_length"] = 4096 * 2
     if "34b" in server_args.model_path.lower():
         model_override_args["image_token_index"] = 64002
+    server_args.json_model_override_args = json.dumps(model_override_args)
 
-    launch_server(server_args, model_override_args, None)
+    launch_server(server_args)
diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py
index 57c491306..fd5995dad 100644
--- a/python/sglang/srt/constrained/fsm_cache.py
+++ b/python/sglang/srt/constrained/fsm_cache.py
@@ -16,6 +16,7 @@ limitations under the License.
 """Cache for the compressed finite state machine."""
 
 from outlines.fsm.json_schema import build_regex_from_schema
+from transformers import AutoTokenizer
 
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
-        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
-        self.json_schema_mode = json_schema_mode
-
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
 
-        from importlib.metadata import version
+        tokenizer_args_dict.setdefault("padding_side", "left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        try:
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        except AttributeError:
+            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+            origin_pad_token_id = tokenizer.pad_token_id
 
-        if version("outlines") >= "0.0.35":
-            from transformers import AutoTokenizer
+            def fset(self, value):
+                self._value = value
 
-            tokenizer_args_dict.setdefault("padding_side", "left")
-            tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_path, **tokenizer_args_dict
+            type(tokenizer).pad_token_id = property(
+                fget=type(tokenizer).pad_token_id.fget, fset=fset
             )
-            try:
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-            except AttributeError:
-                # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
-                origin_pad_token_id = tokenizer.pad_token_id
-
-                def fset(self, value):
-                    self._value = value
-
-                type(tokenizer).pad_token_id = property(
-                    fget=type(tokenizer).pad_token_id.fget, fset=fset
-                )
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-                self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token = (
-                    self.outlines_tokenizer.tokenizer.pad_token
-                )
-                self.outlines_tokenizer.vocabulary = (
-                    self.outlines_tokenizer.tokenizer.get_vocab()
-                )
-        else:
-            self.outlines_tokenizer = TransformerTokenizer(
-                tokenizer_path, **tokenizer_args_dict
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token = (
+                self.outlines_tokenizer.tokenizer.pad_token
+            )
+            self.outlines_tokenizer.vocabulary = (
+                self.outlines_tokenizer.tokenizer.get_vocab()
             )
 
-    def init_value(self, value):
-        if self.json_schema_mode:
-            regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*")
-            return RegexGuide(regex, self.outlines_tokenizer), regex
+    def init_value(self, key):
+        key_type, key_string = key
+        if key_type == "json":
+            regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
+        elif key_type == "regex":
+            regex = key_string
         else:
-            return RegexGuide(value, self.outlines_tokenizer)
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+        return RegexGuide(regex, self.outlines_tokenizer), regex
diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py
index ba626d4cf..e4b316155 100644
--- a/python/sglang/srt/managers/controller_multi.py
+++ b/python/sglang/srt/managers/controller_multi.py
@@ -71,12 +71,10 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +112,6 @@ class ControllerMulti:
             self.server_args,
             self.port_args,
             pipe_controller_writer,
-            self.model_override_args,
             True,
             gpu_ids,
             dp_worker_id,
@@ -189,14 +186,13 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_override_args: dict,
 ):
     """Start a controller process."""
 
     configure_logger(server_args)
 
     try:
-        controller = ControllerMulti(server_args, port_args, model_override_args)
+        controller = ControllerMulti(server_args, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py
index 2ae37059c..fe03ca1d4 100644
--- a/python/sglang/srt/managers/controller_single.py
+++ b/python/sglang/srt/managers/controller_single.py
@@ -40,7 +40,6 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +75,6 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-                model_override_args,
             )
 
         # Launch tp rank 0
@@ -85,7 +83,6 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
@@ -126,7 +123,6 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
@@ -149,7 +145,6 @@ def start_controller_process(
         controller = ControllerSingle(
             server_args,
             port_args,
-            model_override_args,
             gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index d0cfed08c..d2fa67601 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -18,6 +18,7 @@ limitations under the License.
 import asyncio
 import concurrent.futures
 import dataclasses
+import json
 import logging
 import multiprocessing as mp
 import os
@@ -77,7 +78,6 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict = None,
     ):
         self.server_args = server_args
 
@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
         )
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 7bb9c4335..513bc517f 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -15,13 +15,14 @@ limitations under the License.
 
 """A tensor parallel worker."""
 
+import json
 import logging
 import multiprocessing
 import os
 import pickle
 import time
 import warnings
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
 
 import torch
 import torch.distributed
@@ -66,6 +67,7 @@ from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
+# Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
 
@@ -76,11 +78,10 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-        model_override_args: dict,
     ):
         suppress_other_loggers()
 
-        # Copy arguments
+        # Parse arguments
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
@@ -93,9 +94,8 @@ class ModelTpServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
@@ -136,7 +136,7 @@ class ModelTpServer:
             self.max_total_num_tokens - 1,
         )
 
-        # Sync random seed
+        # Sync random seed across TP workers
         server_args.random_seed = broadcast_recv_input(
             [server_args.random_seed],
             self.tp_rank,
@@ -144,7 +144,7 @@ class ModelTpServer:
         )[0]
         set_random_seed(server_args.random_seed)
 
-        # Print info
+        # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
@@ -181,7 +181,7 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
 
-        # Chunked prefill
+        # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         self.current_inflight_req = None
         self.is_mixed_chunk = (
@@ -197,16 +197,6 @@ class ModelTpServer:
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
-                json_schema_mode=False,
-            )
-            self.json_fsm_cache = FSMCache(
-                server_args.tokenizer_path,
-                {
-                    "tokenizer_mode": server_args.tokenizer_mode,
-                    "trust_remote_code": server_args.trust_remote_code,
-                },
-                skip_tokenizer_init=server_args.skip_tokenizer_init,
-                json_schema_mode=True,
             )
         self.jump_forward_cache = JumpForwardCache()
 
@@ -227,11 +217,12 @@ class ModelTpServer:
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(
-                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
-                ):
+                if isinstance(recv_req, TokenizedGenerateReqInput):
                     self.handle_generate_request(recv_req)
                     self.do_not_get_new_batch = False
+                elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+                    self.handle_embedding_request(recv_req)
+                    self.do_not_get_new_batch = False
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
                 elif isinstance(recv_req, AbortReq):
@@ -331,57 +322,56 @@ class ModelTpServer:
 
     def handle_generate_request(
         self,
-        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
+        recv_req: TokenizedGenerateReqInput,
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
         req.sampling_params = recv_req.sampling_params
-        if self.model_runner.is_generation:
-            req.pixel_values = recv_req.pixel_values
-            if req.pixel_values is not None:
-                # Use image hash as fake token_ids, which is then used
-                # for prefix matching
-                image_hash = hash(tuple(recv_req.image_hashes))
-                req.pad_value = [
-                    (image_hash) % self.model_config.vocab_size,
-                    (image_hash >> 16) % self.model_config.vocab_size,
-                    (image_hash >> 32) % self.model_config.vocab_size,
-                    (image_hash >> 64) % self.model_config.vocab_size,
-                ]
-                req.image_sizes = recv_req.image_sizes
-                (
-                    req.origin_input_ids,
-                    req.image_offsets,
-                ) = self.model_runner.model.pad_input_ids(
-                    req.origin_input_ids_unpadded,
-                    req.pad_value,
-                    req.pixel_values,
-                    req.image_sizes,
-                )
-                # Only when pixel values is not None we have modalities
-                req.modalities = recv_req.modalites
-            req.return_logprob = recv_req.return_logprob
-            req.logprob_start_len = recv_req.logprob_start_len
-            req.top_logprobs_num = recv_req.top_logprobs_num
-            req.stream = recv_req.stream
+        req.pixel_values = recv_req.pixel_values
+        if req.pixel_values is not None:
+            # Use image hash as fake token_ids, which is then used
+            # for prefix matching
+            image_hash = hash(tuple(recv_req.image_hashes))
+            req.pad_value = [
+                (image_hash) % self.model_config.vocab_size,
+                (image_hash >> 16) % self.model_config.vocab_size,
+                (image_hash >> 32) % self.model_config.vocab_size,
+                (image_hash >> 64) % self.model_config.vocab_size,
+            ]
+            req.image_sizes = recv_req.image_sizes
+            (
+                req.origin_input_ids,
+                req.image_offsets,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values,
+                req.image_sizes,
+            )
+            # Only when pixel values is not None we have modalities
+            req.modalities = recv_req.modalites
+        req.return_logprob = recv_req.return_logprob
+        req.logprob_start_len = recv_req.logprob_start_len
+        req.top_logprobs_num = recv_req.top_logprobs_num
+        req.stream = recv_req.stream
 
-            # Init regex fsm fron json
+        # Init regex FSM
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
             if req.sampling_params.json_schema is not None:
-                req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
-                    req.sampling_params.json_schema
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("json", req.sampling_params.json_schema)
                 )
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        computed_regex_string
-                    )
-
-            # Init regex fsm
             elif req.sampling_params.regex is not None:
-                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        req.sampling_params.regex
-                    )
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("regex", req.sampling_params.regex)
+                )
+            if not self.disable_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
+                    computed_regex_string
+                )
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -390,16 +380,32 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+        req.sampling_params.max_new_tokens = min(
+            (
+                req.sampling_params.max_new_tokens
+                if req.sampling_params.max_new_tokens is not None
+                else 1 << 30
+            ),
+            self.max_req_input_len - 1 - len(req.origin_input_ids),
+        )
 
-        if self.model_runner.is_generation:
-            req.sampling_params.max_new_tokens = min(
-                (
-                    req.sampling_params.max_new_tokens
-                    if req.sampling_params.max_new_tokens is not None
-                    else 1 << 30
-                ),
-                self.max_req_input_len - 1 - len(req.origin_input_ids),
+        self.waiting_queue.append(req)
+
+    def handle_embedding_request(
+        self,
+        recv_req: TokenizedEmbeddingReqInput,
+    ):
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+        req.tokenizer = self.tokenizer
+        req.sampling_params = recv_req.sampling_params
+
+        # Truncate prompts that are too long
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
             )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
 
         self.waiting_queue.append(req)
 
@@ -892,7 +898,6 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -903,7 +908,6 @@ def run_tp_server(
         tp_rank,
         server_args,
         nccl_port,
-        model_override_args,
     )
     tp_cpu_group = model_server.model_runner.tp_group.cpu_group
 
@@ -920,14 +924,13 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
+            args=(gpu_ids[i], i, server_args, nccl_port),
         )
         proc.start()
         procs.append(proc)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 3d3e0cde9..9c82b2a81 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import json
 import logging
 import pkgutil
 from functools import lru_cache
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index feaf91dd3..d44d61752 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str):
 
 def launch_server(
     server_args: ServerArgs,
-    model_override_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     """Launch an HTTP server."""
@@ -317,7 +316,6 @@ def launch_server(
             tp_rank_range,
             server_args,
             ports[3],
-            model_override_args,
         )
 
         try:
@@ -328,7 +326,7 @@ def launch_server(
         return
 
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
@@ -341,7 +339,7 @@ def launch_server(
 
         proc_controller = mp.Process(
             target=start_controller_process,
-            args=(server_args, port_args, pipe_controller_writer, model_override_args),
+            args=(server_args, port_args, pipe_controller_writer),
         )
         proc_controller.start()
 
@@ -501,7 +499,6 @@ class Runtime:
     def __init__(
         self,
         log_level: str = "error",
-        model_override_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
@@ -525,7 +522,7 @@ class Runtime:
 
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, model_override_args, pipe_writer),
+            args=(self.server_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e21f02108..14dd63b5a 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -76,6 +76,14 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
+    # Model override args in JSON
+    json_model_override_args: str = "{}"
+
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -91,14 +99,6 @@ class ServerArgs:
     enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    # Distributed args
-    nccl_init_addr: Optional[str] = None
-    nnodes: int = 1
-    node_rank: Optional[int] = None
-
-    # Model override args in JSON
-    json_model_override_args: Optional[dict] = None
-
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -385,6 +385,14 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+            default=ServerArgs.json_model_override_args,
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
@@ -459,22 +467,10 @@ class ServerArgs:
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
 
-        # Model override args
-        parser.add_argument(
-            "--json-model-override-args",
-            type=str,
-            help="A dictionary in JSON string format used to override default model configurations.",
-        )
-
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
         args.dp_size = args.data_parallel_size
-        args.json_model_override_args = (
-            json.loads(args.json_model_override_args)
-            if args.json_model_override_args
-            else None
-        )
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
@@ -498,7 +494,7 @@ class ServerArgs:
         self.disable_flashinfer = False
 
 
-def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
+def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
     Prepare the server arguments from the command line arguments.
 
@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
     """
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
-    raw_args = parser.parse_args(args)
+    raw_args = parser.parse_args(argv)
     server_args = ServerArgs.from_cli_args(raw_args)
     return server_args
 
diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py
new file mode 100644
index 000000000..18ae2d8c3
--- /dev/null
+++ b/python/sglang/test/few_shot_gsm8k.py
@@ -0,0 +1,132 @@
+"""
+Run few-shot GSM-8K evaluation.
+
+Usage:
+python3 -m sglang.test.few_shot_gsm8k --num-questions 200
+"""
+
+import argparse
+import ast
+import re
+import time
+
+import numpy as np
+
+from sglang.api import set_default_backend
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
+
+INVALID = -9999999
+
+
+def get_one_example(lines, i, include_answer):
+    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
+    if include_answer:
+        ret += " " + lines[i]["answer"]
+    return ret
+
+
+def get_few_shot_examples(lines, k):
+    ret = ""
+    for i in range(k):
+        ret += get_one_example(lines, i, True) + "\n\n"
+    return ret
+
+
+def get_answer_value(answer_str):
+    answer_str = answer_str.replace(",", "")
+    numbers = re.findall(r"\d+", answer_str)
+    if len(numbers) < 1:
+        return INVALID
+    try:
+        return ast.literal_eval(numbers[-1])
+    except SyntaxError:
+        return INVALID
+
+
+def main(args):
+    # Select backend
+    set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+    questions = []
+    labels = []
+    for i in range(len(lines[:num_questions])):
+        questions.append(get_one_example(lines, i, False))
+        labels.append(get_answer_value(lines[i]["answer"]))
+    assert all(l != INVALID for l in labels)
+    arguments = [{"question": q} for q in questions]
+
+    #####################################
+    ######### SGL Program Begin #########
+    #####################################
+
+    import sglang as sgl
+
+    @sgl.function
+    def few_shot_gsm8k(s, question):
+        s += few_shot_examples + question
+        s += sgl.gen(
+            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+        )
+
+    #####################################
+    ########## SGL Program End ##########
+    #####################################
+
+    # Run requests
+    tic = time.time()
+    states = few_shot_gsm8k.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=args.parallel,
+        progress_bar=True,
+    )
+    latency = time.time() - tic
+
+    preds = []
+    for i in range(len(states)):
+        preds.append(get_answer_value(states[i]["answer"]))
+
+    # print(f"{preds=}")
+    # print(f"{labels=}")
+
+    # Compute accuracy
+    acc = np.mean(np.array(preds) == np.array(labels))
+    invalid = np.mean(np.array(preds) == INVALID)
+
+    # Compute speed
+    num_output_tokens = sum(
+        s.get_meta_info("answer")["completion_tokens"] for s in states
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")
+
+    # Dump results
+    dump_state_text("tmp_output_gsm8k.txt", states)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-shots", type=int, default=5)
+    parser.add_argument("--data-path", type=str, default="test.jsonl")
+    parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--parallel", type=int, default=128)
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+    main(args)
diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py
index bdecdff2f..41f466f73 100644
--- a/python/sglang/test/test_programs.py
+++ b/python/sglang/test/test_programs.py
@@ -7,7 +7,7 @@ import time
 import numpy as np
 
 import sglang as sgl
-from sglang.utils import fetch_and_cache_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 def test_few_shot_qa():
@@ -456,10 +456,6 @@ def test_chat_completion_speculative():
 
 def test_hellaswag_select():
     """Benchmark the accuracy of sgl.select on the HellaSwag dataset."""
-    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
-    lines = fetch_and_cache_jsonl(url)
-
-    # Construct prompts
 
     def get_one_example(lines, i, include_answer):
         ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
         if include_answer:
@@ -472,6 +468,12 @@ def test_hellaswag_select():
             ret += get_one_example(lines, i, True) + "\n\n"
         return ret
 
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
     num_questions = 200
     num_shots = 20
     few_shot_examples = get_few_shot_examples(lines, num_shots)
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index b212f6caa..621efb537 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -12,7 +12,7 @@ import urllib.request
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
-from typing import Union
+from typing import Optional, Union
 
 import numpy as np
 import requests
@@ -38,13 +38,11 @@ def is_same_type(values: list):
 
 def read_jsonl(filename: str):
     """Read a JSONL file."""
-    rets = []
     with open(filename) as fin:
         for line in fin:
             if line.startswith("#"):
                 continue
-            rets.append(json.loads(line))
-    return rets
+            yield json.loads(line)
 
 
 def dump_state_text(filename: str, states: list, mode: str = "w"):
@@ -264,38 +262,35 @@ class LazyImport:
         return module(*args, **kwargs)
 
 
-def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
-    """Read and cache a jsonl file from a url."""
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
 
     # Check if the cache file already exists
-    if os.path.exists(cache_file):
-        print("Loading data from cache...")
-        with open(cache_file, "r") as f:
-            data = [json.loads(line) for line in f]
-    else:
-        print("Downloading data from URL...")
-        # Stream the response to show the progress bar
-        response = requests.get(url, stream=True)
-        response.raise_for_status()  # Check for request errors
+    if os.path.exists(filename):
+        return filename
 
-        # Total size of the file in bytes
-        total_size = int(response.headers.get("content-length", 0))
-        chunk_size = 1024  # Download in chunks of 1KB
+    print(f"Downloading from {url} to {filename}")
 
-        # Use tqdm to display the progress bar
-        with open(cache_file, "wb") as f, tqdm(
-            desc=cache_file,
-            total=total_size,
-            unit="B",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as bar:
-            for chunk in response.iter_content(chunk_size=chunk_size):
-                f.write(chunk)
-                bar.update(len(chunk))
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
 
-        # Convert the data to a list of dictionaries
-        with open(cache_file, "r") as f:
-            data = [json.loads(line) for line in f]
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
 
-    return data
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py
index d4b1354b7..b15308dce 100644
--- a/test/srt/test_moe_eval_accuracy_large.py
+++ b/test/srt/test_moe_eval_accuracy_large.py
@@ -42,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.63, f"{metrics}"
+        assert metrics["score"] >= 0.62, f"{metrics}"
 
     def test_human_eval(self):
         args = SimpleNamespace(
@@ -66,7 +66,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.63, f"{metrics}"
+        assert metrics["score"] >= 0.62, f"{metrics}"
 
 
 if __name__ == "__main__":
diff --git a/test/srt/test_server_args.py b/test/srt/test_server_args.py
index 71129e3eb..d8f31ce1b 100644
--- a/test/srt/test_server_args.py
+++ b/test/srt/test_server_args.py
@@ -1,3 +1,4 @@
+import json
 import unittest
 
 from sglang.srt.server_args import prepare_server_args
@@ -15,7 +16,7 @@ class TestPrepareServerArgs(unittest.TestCase):
         )
         self.assertEqual(server_args.model_path, "model_path")
         self.assertEqual(
-            server_args.json_model_override_args,
+            json.loads(server_args.json_model_override_args),
             {"rope_scaling": {"factor": 2.0, "type": "linear"}},
         )