diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer index 00cf5f46f..88b9496e1 160000 --- a/3rdparty/flashinfer +++ b/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 00cf5f46fdbb4f1dbd9277fe3b842621c1d9e7dc +Subproject commit 88b9496e1a726ddb353eb42887cfc0ab32c99460 diff --git a/benchmark/json_regex_decode/README.md b/benchmark/json_regex_decode/README.md new file mode 100644 index 000000000..853998ecf --- /dev/null +++ b/benchmark/json_regex_decode/README.md @@ -0,0 +1,61 @@ +## Run benchmark + +### Build dataset +``` +pip install wikipedia +python3 build_dataset.py +``` + +### Dependencies + +``` +llama_cpp_python 0.2.19 +guidance 0.1.10 +vllm 0.2.5 +outlines 0.0.22 +``` + +### Benchmark sglang + +Run llama-7b + +``` +python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +Run mixtral-8x7b +(When there is a CUDA out-of-memory error, try to reduce the `--mem-fraction-static`) + +``` +python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8 +``` + +Benchmark + +``` +python3 bench_sglang.py --num-questions 10 +``` + + +### Benchmark vllm + +Run llama-7b + +``` +python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +Benchmark + +``` +python3 bench_other.py --backend vllm --num-questions 10 +``` + + +### Benchmark guidance + +Run llama-7b and benchmark + +``` +python3 bench_other.py --backend guidance --num-questions 10 --parallel 1 +``` \ No newline at end of file diff --git a/benchmark/json_regex_decode/bench_other.py b/benchmark/json_regex_decode/bench_other.py new file mode 100644 index 000000000..694979358 --- /dev/null +++ b/benchmark/json_regex_decode/bench_other.py @@ -0,0 +1,125 @@ +import argparse +import json +import time +from concurrent.futures import ThreadPoolExecutor +from functools import partial + +from sglang.test.test_utils import ( + add_common_other_args_and_parse, + call_generate_outlines, +) +from sglang.utils import dump_state_text, read_jsonl +from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT +from tqdm import tqdm + +REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" + + +# fmt: off +def json_decode(document, generate): + s = "Please extract the information of a city from the following wikipedia page.\n" + s += "Page begin.\n" + document + "Page end.\n" + s += "Here is the name, country, and symbol of the city in JSON format.\n" + s += "{\n" + s += ' "name": ' + s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += ' "country": ' + s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += ' "latitude": ' + s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" + s += ' "population": ' + s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n" + s += ' "top 3 landmarks": ' + s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n" + s += "}\n" + + return s +# fmt: on + + +def main(args): + lines = read_jsonl(args.data_path) + arguments = [] + for i in range(len(lines[: args.num_questions])): + arguments.append( + { + "document": lines[i]["document"], + } + ) + states = [None] * len(arguments) + + # Select backend + if args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + generate = partial(call_generate_outlines, url=url, temperature=0) + elif args.backend == "guidance": + from guidance import gen, models + + model = models.LlamaCpp( + "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf", + 
            n_gpu_layers=-1,
+            n_ctx=4096,
+        )
+
+        def generate(prompt, max_tokens, stop=None, regex=None):
+            out = (
+                model
+                + prompt
+                + gen(
+                    name="answer",
+                    max_tokens=max_tokens,
+                    temperature=0,
+                    stop=stop,
+                    regex=regex,
+                )
+            )
+            return out["answer"]
+
+        # warmup
+        for _ in range(3):
+            generate("Hello!" * 10, max_tokens=64, stop=None)
+    else:
+        raise ValueError(f"Invalid backend: {args.backend}")
+
+    # Run requests
+    def get_one_answer(i):
+        states[i] = json_decode(generate=generate, **arguments[i])
+
+    tic = time.time()
+    if args.parallel == 1:
+        for i in tqdm(range(len(arguments))):
+            get_one_answer(i)
+    else:
+        with ThreadPoolExecutor(args.parallel) as executor:
+            rets = executor.map(get_one_answer, list(range(len(arguments))))
+            for _ in rets:
+                pass
+
+    latency = time.time() - tic
+
+    # Compute accuracy
+    print(f"Latency: {latency:.3f}")
+
+    # Write results
+    dump_state_text(f"tmp_output_{args.backend}.txt", states)
+
+    with open(args.result_file, "a") as fout:
+        value = {
+            "task": "json_regex_decode",
+            "backend": args.backend,
+            "num_gpus": 1,
+            "latency": round(latency, 3),
+            "num_requests": args.num_questions,
+            "other": {
+                "parallel": args.parallel,
+            },
+        }
+        fout.write(json.dumps(value) + "\n")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data-path", type=str, default="questions.jsonl")
+    parser.add_argument("--num-questions", type=int, default=20)
+    args = add_common_other_args_and_parse(parser)
+    main(args)
diff --git a/benchmark/json_regex_decode/bench_sglang.py b/benchmark/json_regex_decode/bench_sglang.py
new file mode 100644
index 000000000..ce3fe7579
--- /dev/null
+++ b/benchmark/json_regex_decode/bench_sglang.py
@@ -0,0 +1,100 @@
+import argparse
+import json
+import time
+
+import sglang as sgl
+from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
+from sglang.test.test_utils import (
+    add_common_sglang_args_and_parse,
+    select_sglang_backend,
+)
+from sglang.utils import dump_state_text, read_jsonl
+
+REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
+
+# fmt: off
+@sgl.function
+def json_warm_up(s):
+    s += "The information about Hogwarts is in the following JSON format.\n"
+    with s.var_scope("json_output"):
+        s += "{\n"
+        s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
+        s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
+        s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
+        s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
+        s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
+        s += "}\n"
+    print(f'The warm-up JSON result is:\n{s["json_output"]}')
+# fmt: on
+
+# fmt: off
+@sgl.function
+def json_decode(s, document):
+    s += "Please extract the information of a city from the following wikipedia page.\n"
+    s += "Page begin.\n" + document + "Page end.\n"
+    s += "Here is the name, country, and symbol of the city in JSON format.\n"
+    with s.var_scope("json_output"):
+        s += "{\n"
+        s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
+        s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
+        s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
+        s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
+        s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24,
regex=REGEX_LIST) + "\n" + s += "}\n" +# fmt: on + + +def main(args): + lines = read_jsonl(args.data_path) + arguments = [] + for i in range(len(lines[: args.num_questions])): + arguments.append( + { + "document": lines[i]["document"], + } + ) + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Warm up + json_warm_up.run().sync() + + # Run requests + tic = time.time() + states = json_decode.run_batch(arguments, temperature=0, num_threads=args.parallel) + for state in states: + state.sync() + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(f"tmp_{args.backend}_json_results.txt", "w") as fout: + for state in states: + fout.write(state["json_output"] + "\n") + + with open(args.result_file, "a") as fout: + value = { + "task": "json_regex_decode", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_questions, + "other": { + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="questions.jsonl") + parser.add_argument("--num-questions", type=int, default=20) + args = add_common_sglang_args_and_parse(parser) + main(args) diff --git a/benchmark/json_regex_decode/build_dataset.py b/benchmark/json_regex_decode/build_dataset.py new file mode 100644 index 000000000..1396e5ede --- /dev/null +++ b/benchmark/json_regex_decode/build_dataset.py @@ -0,0 +1,58 @@ +import json + +import transformers +import wikipedia + +model_path = "meta-llama/Llama-2-7b-chat-hf" +t = transformers.AutoTokenizer.from_pretrained(model_path) +city_names = [ + "los angles", + "london", + "tokyo", + "beijing", + "singapore", + "paris", + "dubai", + "sydney", + "moscow", + "rome", + "toronto", + "rio de janeiro", + "istanbul", + "berlin", + "auckland", + "buenos aires", + "mexico city", + "mumbai", + "seoul", + "bangkok", + "cairo", + "athens", + "jerusalem", +] + + +def get_content(city_name): + content = str(wikipedia.page(city_name).content) + content = content.replace("\n\n", "\n") + + tokens = t.encode(content) + + expected_tokens = 3000 + truncate_len = int((expected_tokens / len(tokens)) * len(content)) + truncate_content = content[:truncate_len] + truncate_tokens = t.encode(truncate_content) + + # Count token + print( + f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}" + ) + + return truncate_content + + +if __name__ == "__main__": + with open("questions.jsonl", "w") as fout: + for city_name in city_names: + truncate_content = get_content(city_name) + fout.write(json.dumps({"document": truncate_content}) + "\n") diff --git a/benchmark/multi_turns/README.md b/benchmark/multi_turns/README.md new file mode 100644 index 000000000..f4cc55360 --- /dev/null +++ b/benchmark/multi_turns/README.md @@ -0,0 +1,66 @@ +### Benchmark sglang + +Run llama-7b + +``` +python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 +``` + +Run mixtral-8x7b +(When there is a CUDA out-of-memory error, try to reduce the `--mem-fraction-static`) + +``` +python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8 +``` + +Benchmark(short output) + +``` +python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf +``` + +Benchmark(long output) + +``` +python3 
bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long +``` + +### Benchmark vLLM + +Run llama-7b + +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000 +``` + +Run mixtral-8x7b + +``` +python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8 +``` + +Benchmark(short output) + +``` +python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm +``` + +Benchmark(long output) + +``` +python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long +``` + +### Benchmark guidance + +Benchmark llama-7b(short output) + +``` +python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 +``` + +Benchmark llama-7b(long output) + +``` +python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --long +``` \ No newline at end of file diff --git a/benchmark/multi_turns/bench_other.py b/benchmark/multi_turns/bench_other.py new file mode 100644 index 000000000..d7389ff86 --- /dev/null +++ b/benchmark/multi_turns/bench_other.py @@ -0,0 +1,133 @@ +import json +import time +from argparse import ArgumentParser +from concurrent.futures import ThreadPoolExecutor + +import requests +from sglang.test.test_utils import add_common_other_args_and_parse +from sglang.utils import dump_state_text +from tqdm import tqdm +from vllm.transformers_utils.tokenizer import get_tokenizer + +from data_gen import gen_arguments + + +def get_generate(args): + # Select backend + if args.backend == "vllm": + url = f"{args.host}:{args.port}/generate" + + def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1): + data = { + "prompt": prompt, + "temperature": temperature, + "max_tokens": max_tokens, + "ignore_eos": True, + "stop": stop, + "stream": False, + "n": n, + } + res = requests.post(url, json=data) + assert res.status_code == 200 + return res.json()["text"][0][len(prompt) :] + + elif args.backend == "guidance": + from guidance import gen, models + + model = models.LlamaCpp( + "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf", + n_gpu_layers=-1, + n_ctx=4096, + ) + + def generate(prompt, max_tokens, stop=None): + out = ( + model + + prompt + + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop) + ) + return out["answer"] + + # warmup + for _ in range(3): + generate("Hello!" 
* 10, max_tokens=64, stop=None) + else: + raise ValueError(f"Invalid backend: {args.backend}") + + return generate + + +def multi_turns(generate, qas): + s = "" + for qa in qas: + s += qa["prompt"] + s += generate(s, max_tokens=qa["new_tokens"]) + + return s + + +def main(args): + print(args) + + tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) + + multi_qas = gen_arguments(args, tokenizer) + + states = [None] * args.num_qa + + generate = get_generate(args) + + def get_one_answer(i): + states[i] = multi_turns(generate=generate, **multi_qas[i]) + + tic = time.time() + if args.parallel == 1: + for i in tqdm(range(len(multi_qas))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + rets = executor.map(get_one_answer, list(range(len(multi_qas)))) + for _ in rets: + pass + + latency = time.time() - tic + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "multi_turns", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_qa, + "num_turns": args.turns, + "other": { + "parallel": args.parallel, + "output_mode": "long" if args.long else "short", + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--turns", type=int, default=4) + parser.add_argument("--num-qa", type=int, default=20) + parser.add_argument("--min-len-q", type=int, default=256) + parser.add_argument("--max-len-q", type=int, default=512) + parser.add_argument("--min-len-a", type=int, default=4) + parser.add_argument("--max-len-a", type=int, default=8) + parser.add_argument("--tokenizer", type=str, required=True) + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--long", action="store_true") + args = add_common_other_args_and_parse(parser) + + if args.long: + args.min_len_a = 256 + args.max_len_a = 512 + args.num_qa = 20 + main(args) diff --git a/benchmark/multi_turns/bench_sglang.py b/benchmark/multi_turns/bench_sglang.py new file mode 100644 index 000000000..a67beccd4 --- /dev/null +++ b/benchmark/multi_turns/bench_sglang.py @@ -0,0 +1,77 @@ +import json +import time +from argparse import ArgumentParser + +import sglang as sgl +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) +from sglang.utils import dump_state_text +from vllm.transformers_utils.tokenizer import get_tokenizer + +from data_gen import gen_arguments + + +@sgl.function +def multi_turns(s, qas): + for qa in qas: + s += qa["prompt"] + s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True) + + +def main(args): + print(args) + + tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) + + multi_qas = gen_arguments(args, tokenizer) + + backend = select_sglang_backend(args) + + tic = time.time() + states = multi_turns.run_batch( + multi_qas, temperature=0, backend=backend, num_threads=args.parallel + ) + for state in states: + state.sync() + latency = time.time() - tic + + print(f"Latency: {latency:.3f}") + + dump_state_text(f"tmp_output_{args.backend}.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "multi_turns", + "backend": args.backend, + "num_gpus": 1, + "latency": round(latency, 3), + "num_requests": args.num_qa, + "num_turns": args.turns, + "other": { + "parallel": args.parallel, + 
"output_mode": "long" if args.long else "short", + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--turns", type=int, default=4) + parser.add_argument("--num-qa", type=int, default=20) + parser.add_argument("--min-len-q", type=int, default=256) + parser.add_argument("--max-len-q", type=int, default=512) + parser.add_argument("--min-len-a", type=int, default=4) + parser.add_argument("--max-len-a", type=int, default=8) + parser.add_argument("--tokenizer", type=str, required=True) + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--long", action="store_true") + args = add_common_sglang_args_and_parse(parser) + + if args.long: + args.min_len_a = 256 + args.max_len_a = 512 + args.num_qa = 20 + main(args) diff --git a/benchmark/multi_turns/data_gen.py b/benchmark/multi_turns/data_gen.py new file mode 100644 index 000000000..043c07a76 --- /dev/null +++ b/benchmark/multi_turns/data_gen.py @@ -0,0 +1,29 @@ +import random +import string + +random.seed(42) + + +def gen_prompt(tokenizer, token_num): + cha_set = string.ascii_letters + string.digits + ret = "".join(random.choices(cha_set, k=token_num)) + while len(tokenizer(ret).input_ids) < token_num: + ret += random.choice(cha_set) + return ret + + +def gen_arguments(args, tokenizer): + multi_qas = [{"qas": []} for _ in range(args.num_qa)] + for i in range(args.num_qa): + qas = multi_qas[i]["qas"] + for _ in range(args.turns): + prompt_len = random.randint(args.min_len_q, args.max_len_q) + new_tokens = random.randint(args.min_len_a, args.max_len_a) + qas.append( + { + "prompt": gen_prompt(tokenizer, prompt_len), + "new_tokens": new_tokens, + } + ) + + return multi_qas diff --git a/python/sglang/api.py b/python/sglang/api.py index 2e35d8888..6dde54c59 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -37,6 +37,7 @@ def gen( top_k: Optional[int] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + ignore_eos: Optional[bool] = None, dtype: Optional[type] = None, choices: Optional[List[str]] = None, regex: Optional[str] = None, @@ -60,6 +61,7 @@ def gen( top_k, frequency_penalty, presence_penalty, + ignore_eos, dtype, regex, ) @@ -74,6 +76,7 @@ def gen_int( top_k: Optional[int] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + ignore_eos: Optional[bool] = None, ): return SglGen( name, @@ -84,6 +87,7 @@ def gen_int( top_k, frequency_penalty, presence_penalty, + ignore_eos, int, None, ) @@ -98,6 +102,7 @@ def gen_string( top_k: Optional[int] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + ignore_eos: Optional[bool] = None, ): return SglGen( name, @@ -108,6 +113,7 @@ def gen_string( top_k, frequency_penalty, presence_penalty, + ignore_eos, str, None, ) diff --git a/python/sglang/backend/anthropic.py b/python/sglang/backend/anthropic.py index 77d7f5127..c561f3b5d 100644 --- a/python/sglang/backend/anthropic.py +++ b/python/sglang/backend/anthropic.py @@ -4,7 +4,7 @@ import numpy as np from sglang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template from sglang.lang.interpreter import StreamExecutor -from sglang.lang.ir import SamplingParams +from sglang.lang.ir import SglSamplingParams try: import anthropic @@ -28,7 +28,7 @@ class Anthropic(BaseBackend): def generate( self, s: StreamExecutor, - sampling_params: SamplingParams, + sampling_params: 
SglSamplingParams, ): prompt = s.text_ ret = anthropic.Anthropic().completions.create( @@ -43,7 +43,7 @@ class Anthropic(BaseBackend): def generate_stream( self, s: StreamExecutor, - sampling_params: SamplingParams, + sampling_params: SglSamplingParams, ): prompt = s.text_ generator = anthropic.Anthropic().completions.create( diff --git a/python/sglang/backend/base_backend.py b/python/sglang/backend/base_backend.py index 7f59f5b15..0bbf3ef3e 100644 --- a/python/sglang/backend/base_backend.py +++ b/python/sglang/backend/base_backend.py @@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Union from sglang.lang.chat_template import get_chat_template from sglang.lang.interpreter import StreamExecutor -from sglang.lang.ir import SamplingParams +from sglang.lang.ir import SglSamplingParams class BaseBackend: @@ -48,14 +48,14 @@ class BaseBackend: def generate( self, s: StreamExecutor, - sampling_params: SamplingParams, + sampling_params: SglSamplingParams, ): raise NotImplementedError() def generate_stream( self, s: StreamExecutor, - sampling_params: SamplingParams, + sampling_params: SglSamplingParams, ): raise NotImplementedError() diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py index b22e149c9..1b23b5f96 100644 --- a/python/sglang/backend/openai.py +++ b/python/sglang/backend/openai.py @@ -4,7 +4,7 @@ import numpy as np from sglang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template from sglang.lang.interpreter import StreamExecutor -from sglang.lang.ir import SamplingParams +from sglang.lang.ir import SglSamplingParams try: import openai @@ -73,7 +73,7 @@ class OpenAI(BaseBackend): def generate( self, s: StreamExecutor, - sampling_params: SamplingParams, + sampling_params: SglSamplingParams, ): if sampling_params.dtype is None: if self.is_chat_model: @@ -122,7 +122,7 @@ class OpenAI(BaseBackend): def generate_stream( self, s: StreamExecutor, - sampling_params: SamplingParams, + sampling_params: SglSamplingParams, ): if sampling_params.dtype is None: if self.is_chat_model: diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index 9d15be72a..9773d4b39 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -7,7 +7,7 @@ from sglang.backend.base_backend import BaseBackend from sglang.global_config import global_config from sglang.lang.chat_template import get_chat_template_by_model_path from sglang.lang.interpreter import StreamExecutor -from sglang.lang.ir import SamplingParams, SglArgument +from sglang.lang.ir import SglSamplingParams, SglArgument from sglang.utils import encode_image_base64, find_printable_text, http_request @@ -55,7 +55,7 @@ class RuntimeEndpoint(BaseBackend): def generate( self, s: StreamExecutor, - sampling_params: SamplingParams, + sampling_params: SglSamplingParams, ): if sampling_params.dtype is None: data = { @@ -87,7 +87,7 @@ class RuntimeEndpoint(BaseBackend): def generate_stream( self, s: StreamExecutor, - sampling_params: SamplingParams, + sampling_params: SglSamplingParams, ): if sampling_params.dtype is None: data = { diff --git a/python/sglang/backend/tgi.py b/python/sglang/backend/tgi.py index e5462218d..be3f3fea4 100644 --- a/python/sglang/backend/tgi.py +++ b/python/sglang/backend/tgi.py @@ -7,7 +7,7 @@ from typing import List, Optional, Union from sglang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template_by_model_path from 
sglang.lang.interpreter import StreamExecutor -from sglang.lang.ir import SamplingParams +from sglang.lang.ir import SglSamplingParams from sglang.utils import http_request @@ -138,7 +138,7 @@ class TGI(BaseBackend): self, s: StreamExecutor, choices: List[str], - sampling_params: SamplingParams, + sampling_params: SglSamplingParams, ): decision = self.retry_for_expected( s.text_, @@ -152,7 +152,7 @@ class TGI(BaseBackend): s: StreamExecutor, max_tokens: int, stop: Union[str, List[str]], - sampling_params: SamplingParams, + sampling_params: SglSamplingParams, dtype: Optional[str] = None, ): if dtype is None: diff --git a/python/sglang/lang/compiler.py b/python/sglang/lang/compiler.py index 0d1ba68c6..8f9259096 100644 --- a/python/sglang/lang/compiler.py +++ b/python/sglang/lang/compiler.py @@ -6,7 +6,7 @@ from typing import List, Union from sglang.global_config import global_config from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program from sglang.lang.ir import ( - SamplingParams, + SglSamplingParams, SglArgument, SglConstantText, SglExpr, @@ -140,7 +140,7 @@ class CompiledFunction: kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()} kwargs.update(self.function.bind_arguments) - default_sampling_para = SamplingParams( + default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, temperature=temperature, @@ -173,7 +173,7 @@ class CompiledFunction: backend = backend or global_config.default_backend - default_sampling_para = SamplingParams( + default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, temperature=temperature, diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 175d2afa9..9486d4406 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -292,7 +292,7 @@ class StreamExecutor: assert isinstance(other, SglExpr), f"{other}" - if isinstance(other, (SglConstantText, SglArgument)): + if isinstance(other, SglConstantText): self._execute_fill(other.value) elif isinstance(other, SglGen): self._execute_gen(other) @@ -332,8 +332,6 @@ class StreamExecutor: def _execute_image(self, expr: SglImage): path = expr.path - if isinstance(path, SglArgument): - path = path.value base64_data = encode_image_base64(path) @@ -419,7 +417,7 @@ class StreamExecutor: "role": expr.role, "content": [{"type": "text", "text": new_text}], } - for (image_path, image_base64_data) in self.cur_images: + for image_path, image_base64_data in self.cur_images: last_msg["content"].append( { "type": "image_url", @@ -480,6 +478,7 @@ class StreamExecutor: "top_k", "frequency_penalty", "presence_penalty", + "ignore_eos", "dtype", "regex", ]: diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 09cf9ad2a..33612c6b5 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg @dataclasses.dataclass -class SamplingParams: +class SglSamplingParams: max_new_tokens: int = 16 stop: Union[str, List[str]] = () temperature: float = 1.0 @@ -21,13 +21,14 @@ class SamplingParams: top_k: int = -1 # -1 means disable frequency_penalty: float = 0.0 presence_penalty: float = 0.0 + ignore_eos: bool = False # for constrained generation, not included in to_xxx_kwargs dtype: Optional[str] = None regex: Optional[str] = None def clone(self): - return SamplingParams( + return SglSamplingParams( self.max_new_tokens, self.stop, self.temperature, @@ -67,6 +68,7 @@ class 
SamplingParams: "top_k": self.top_k, "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, + "ignore_eos": self.ignore_eos, "regex": self.regex, } @@ -98,13 +100,14 @@ class SglFunction: top_k: int = -1, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, + ignore_eos: bool = False, stream: bool = False, backend=None, **kwargs, ): from sglang.lang.interpreter import run_program - default_sampling_para = SamplingParams( + default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, temperature=temperature, @@ -112,9 +115,9 @@ class SglFunction: top_k=top_k, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + ignore_eos=ignore_eos, ) backend = backend or global_config.default_backend - kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()} return run_program(self, backend, args, kwargs, default_sampling_para, stream) def run_batch( @@ -128,6 +131,7 @@ class SglFunction: top_k: int = -1, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, + ignore_eos: bool = False, backend=None, num_threads: Union[str, int] = "auto", progress_bar: bool = False, @@ -139,7 +143,7 @@ class SglFunction: return [] assert isinstance(batch_kwargs[0], dict) - default_sampling_para = SamplingParams( + default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, temperature=temperature, @@ -147,11 +151,9 @@ class SglFunction: top_k=top_k, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + ignore_eos=ignore_eos, ) backend = backend or global_config.default_backend - batch_kwargs = [ - {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs - ] return run_program_batch( self, backend, @@ -321,12 +323,13 @@ class SglGen(SglExpr): top_k, frequency_penalty, presence_penalty, + ignore_eos, dtype, regex, ): super().__init__() self.name = name - self.sampling_params = SamplingParams( + self.sampling_params = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, temperature=temperature, @@ -334,6 +337,7 @@ class SglGen(SglExpr): top_k=top_k, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + ignore_eos=ignore_eos, dtype=dtype, regex=regex, ) diff --git a/python/sglang/lang/tracer.py b/python/sglang/lang/tracer.py index 7e89320b5..1ab1e8238 100644 --- a/python/sglang/lang/tracer.py +++ b/python/sglang/lang/tracer.py @@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend): try: with TracingScope(tracer): tracer.ret_value = program.func(tracer, **arguments) - except StopTracing: + except (StopTracing, TypeError): + # Some exceptions may not be catched pass # Run and cache prefix diff --git a/python/sglang/srt/backend_config.py b/python/sglang/srt/backend_config.py new file mode 100644 index 000000000..b29f08387 --- /dev/null +++ b/python/sglang/srt/backend_config.py @@ -0,0 +1,12 @@ +""" +Backend configurations, may vary with different serving platforms. 
+"""
+from dataclasses import dataclass
+
+
+@dataclass
+class BackendConfig:
+    extend_dependency_time: float = 0.03
+
+
+GLOBAL_BACKEND_CONFIG = BackendConfig()
diff --git a/python/sglang/srt/managers/router/manager.py b/python/sglang/srt/managers/router/manager.py
index 8ac027dfa..b6abc77c5 100644
--- a/python/sglang/srt/managers/router/manager.py
+++ b/python/sglang/srt/managers/router/manager.py
@@ -1,6 +1,5 @@
 import asyncio
 import logging
-from typing import List, Tuple
 import uvloop
 import zmq
@@ -8,6 +7,7 @@ import zmq.asyncio
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_exception_traceback
+from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -28,6 +28,9 @@ class RouterManager:
         self.model_client = model_client
         self.recv_reqs = []
+        # Init some configs
+        self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
+
     async def loop_for_forward(self):
         while True:
             next_step_input = list(self.recv_reqs)
@@ -37,7 +40,12 @@ class RouterManager:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)
-            # await for a while to accept input requests
+            # async sleep for receiving the subsequent request and avoiding cache misses
+            if len(out_pyobjs) != 0:
+                has_finished = any([obj.finished for obj in out_pyobjs])
+                if has_finished:
+                    await asyncio.sleep(self.extend_dependency_time)
+
             await asyncio.sleep(0.001)
     async def loop_for_recv_requests(self):
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 78080fcb3..5aec63311 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     get_exception_traceback,
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
                    if self.running_batch.is_empty():
                        self.running_batch = None
                        break
+            else:
+                # check the available size
+                available_size = (
+                    self.token_to_kv_pool.available_size()
+                    + self.tree_cache.evictable_size()
+                )
+                if available_size != self.max_total_num_token:
+                    logger.warning(
+                        "Warning: "
+                        f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
+                        "KV cache pool leak detected!"
+ ) if self.running_batch is not None and self.tp_rank == 0: if self.decode_forward_ct >= 20: @@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service): token_ids = tuple(req.input_ids + req.output_ids) seq_len = len(token_ids) - 1 indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len] - prefix_len = self.tree_cache.insert(token_ids, indices.clone()) + prefix_len = self.tree_cache.insert( + token_ids[:seq_len], indices.clone() + ) self.token_to_kv_pool.free(indices[:prefix_len]) self.req_to_token_pool.free(req_pool_idx) diff --git a/python/sglang/srt/managers/router/scheduler.py b/python/sglang/srt/managers/router/scheduler.py index 1376b329a..582268f60 100644 --- a/python/sglang/srt/managers/router/scheduler.py +++ b/python/sglang/srt/managers/router/scheduler.py @@ -18,7 +18,7 @@ class Scheduler: self.tree_cache = tree_cache def new_token_estimation_ratio(self): - return 0.4 if self.schedule_heuristic != "fcfs" else 0.5 + return 0.5 if self.schedule_heuristic != "fcfs" else 0.6 def get_priority_queue(self, forward_queue): if self.schedule_heuristic == "lpm": diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py index e80766984..89b89d57d 100644 --- a/python/sglang/srt/sampling_params.py +++ b/python/sglang/srt/sampling_params.py @@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6 class SamplingParams: def __init__( self, + max_new_tokens: int = 16, + stop: Optional[Union[str, List[str]]] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, - stop: Optional[Union[str, List[str]]] = None, - max_new_tokens: int = 16, ignore_eos: bool = False, skip_special_tokens: bool = True, dtype: Optional[str] = None, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6f35a8c82..51d35069c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -24,6 +24,8 @@ class ServerArgs: def __post_init__(self): if self.tokenizer_path is None: self.tokenizer_path = self.model_path + if self.tp_size > 1: + self.mem_fraction_static = 0.8 @staticmethod def add_cli_args(parser: argparse.ArgumentParser): diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 161953159..65ca36c2f 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1): return pred +def call_generate_outlines( + prompt, temperature, max_tokens, url, stop=[], regex=None, n=1 +): + data = { + "prompt": prompt, + "temperature": temperature, + "max_tokens": max_tokens, + "stop": stop, + "regex": regex, + "n": n, + } + res = requests.post(url, json=data) + assert res.status_code == 200 + if n == 1: + pred = res.json()["text"][0][len(prompt) :] + else: + pred = [x[len(prompt) :] for x in res.json()["text"]] + return pred + + def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url): data = { "text": prompt, diff --git a/python/sglang/utils.py b/python/sglang/utils.py index c3a40b0ce..ac103415e 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"): if isinstance(s, str): pass elif isinstance(s, ProgramState): - s = s.text().strip() + s = s.text() else: s = str(s)
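
Usage note: the changes above thread a new `ignore_eos` flag from `sgl.gen()` through `SglSamplingParams` into the runtime request, alongside the `regex` constraint that the json_regex_decode benchmark relies on. Below is a minimal sketch of how a client might exercise both together; it assumes a local sglang server started as in the READMEs above (the `http://localhost:30000` endpoint, the example function name, and the sample document are illustrative assumptions, not part of this diff).

```python
# Minimal sketch (not part of the diff): uses the new `ignore_eos` flag and the
# existing `regex` constraint. Assumes a server launched as in the READMEs, e.g.
#   python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
import sglang as sgl
from sglang.lang.ir import REGEX_STRING


@sgl.function
def extract_city(s, document):
    s += "Page begin.\n" + document + "Page end.\n"
    s += 'The name of the city in JSON format.\n{\n "name": '
    # Constrain the output to a quoted string, as the json_regex_decode benchmark does.
    s += sgl.gen("name", max_tokens=8, regex=REGEX_STRING) + "\n}\n"
    # ignore_eos=True keeps decoding for the full token budget even if EOS is sampled,
    # which is how the multi_turns benchmark obtains fixed-length answers.
    s += sgl.gen("filler", max_tokens=16, ignore_eos=True)


if __name__ == "__main__":
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    state = extract_city.run(document="Paris is the capital of France.", temperature=0)
    print(state["name"])
```

Forcing generation past EOS with `ignore_eos=True` is what lets the multi_turns benchmark produce outputs of a fixed token length per turn, so latency comparisons across backends are not skewed by early stopping.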