From 55c16436273d4a42f7cfe342df5f10ad05a8d0fe Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 26 May 2024 12:51:45 -0700 Subject: [PATCH] Improve benchmark scripts & rename some scripts (#477) --- benchmark/gsm8k/bench_other.py | 1 + .../latency_throughput/bench_throughput.py | 81 ++++++++++++++----- benchmark/latency_throughput/test_latency.py | 34 ++++++-- python/sglang/srt/hf_transformers_utils.py | 4 +- .../sglang/srt/managers/router/infer_batch.py | 3 +- .../sglang/srt/managers/router/model_rpc.py | 41 +++++----- .../srt/managers/router/model_runner.py | 8 +- .../sglang/srt/managers/router/scheduler.py | 12 +-- python/sglang/srt/server_args.py | 13 ++- python/sglang/test/test_utils.py | 26 ++++++ 10 files changed, 161 insertions(+), 62 deletions(-) diff --git a/benchmark/gsm8k/bench_other.py b/benchmark/gsm8k/bench_other.py index 2815a079e..8d61858c6 100644 --- a/benchmark/gsm8k/bench_other.py +++ b/benchmark/gsm8k/bench_other.py @@ -65,6 +65,7 @@ def main(args): def get_one_answer(i): answer = call_generate( prompt=few_shot_examples + questions[i], + #prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i], temperature=0, max_tokens=256, stop="Question", diff --git a/benchmark/latency_throughput/bench_throughput.py b/benchmark/latency_throughput/bench_throughput.py index 719eca12c..f058ecad4 100644 --- a/benchmark/latency_throughput/bench_throughput.py +++ b/benchmark/latency_throughput/bench_throughput.py @@ -26,8 +26,7 @@ from typing import AsyncGenerator, List, Tuple import aiohttp import numpy as np from tqdm.asyncio import tqdm_asyncio -from transformers import PreTrainedTokenizerBase -from vllm.transformers_utils.tokenizer import get_tokenizer +from transformers import AutoTokenizer # (prompt len, output len, latency) REQUEST_LATENCY: List[Tuple[int, int, float]] = [] @@ -36,7 +35,7 @@ REQUEST_LATENCY: List[Tuple[int, int, float]] = [] def sample_requests( dataset_path: str, num_requests: int, - tokenizer: PreTrainedTokenizerBase, + tokenizer: AutoTokenizer, ) -> List[Tuple[str, int, int]]: # Load the dataset. with open(dataset_path) as f: @@ -150,22 +149,47 @@ async def send_request( "inputs": prompt, "parameters": params, } + elif backend == "xinfer": + pass else: raise ValueError(f"Unknown backend: {backend}") - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout) as session: - while True: - async with session.post(api_url, headers=headers, json=pload) as response: - chunks = [] - async for chunk, _ in response.content.iter_chunks(): - chunks.append(chunk) - output = b"".join(chunks).decode("utf-8") - output = json.loads(output) + if backend != "xinfer": + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout) as session: + while True: + async with session.post(api_url, headers=headers, json=pload) as response: + chunks = [] + async for chunk, _ in response.content.iter_chunks(): + chunks.append(chunk) + output = b"".join(chunks).decode("utf-8") + output = json.loads(output) - # Re-send the request if it failed. - if "error" not in output: - break + # Re-send the request if it failed. 
+                if "error" not in output:
+                    break
+                else:
+                    print(output)
+    else:
+        import grpc
+        from xlm.proto import sampler_pb2, sampler_pb2_grpc
+
+        api_url = api_url.replace("http://", "").replace("/generate", "")
+        sampler_channel = grpc.aio.insecure_channel(api_url)
+        sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
+
+        request_end_time = time.perf_counter()
+        sample_request = sampler_pb2.SampleTextRequest(
+            prompt=prompt,
+            settings=sampler_pb2.SampleSettings(
+                max_len=output_len,
+                rng_seed=0,
+                temperature=0,
+                nucleus_p=1,
+            ),
+        )
+        stream = sampler.SampleText(sample_request)
+        response = "".join([x.text async for x in stream])
 
     request_end_time = time.perf_counter()
     request_latency = request_end_time - request_start_time
@@ -204,8 +228,18 @@ def main(args: argparse.Namespace):
     np.random.seed(args.seed)
 
     api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
-    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code)
+
+    if args.dataset:
+        input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
+    else:
+        input_lens = np.random.randint(args.input_len * args.range_ratio, args.input_len + 1, size=args.num_prompts)
+        output_lens = np.random.randint(args.output_len * args.range_ratio, args.output_len + 1, size=args.num_prompts)
+        offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts)
+        input_requests = []
+        for i in range(args.num_prompts):
+            prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])])
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
 
     benchmark_start_time = time.perf_counter()
     asyncio.run(
@@ -246,16 +280,21 @@ if __name__ == "__main__":
     parser.add_argument(
         "--backend",
         type=str,
-        default="vllm",
-        choices=["vllm", "tgi", "srt", "lightllm"],
+        default="srt",
+        choices=["vllm", "tgi", "srt", "lightllm", "xinfer"],
     )
     parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
     parser.add_argument(
-        "--dataset", type=str, required=True, help="Path to the dataset."
+        "--dataset", type=str, help="Path to the dataset."
     )
+    parser.add_argument("--input-len", type=int, default=1024)
+    parser.add_argument("--output-len", type=int, default=128)
+    parser.add_argument("--range-ratio", type=float, default=1.0)
     parser.add_argument(
-        "--tokenizer", type=str, required=True, help="Name or path of the tokenizer."
+        "--tokenizer", type=str,
+        default="NousResearch/Meta-Llama-3-8B",
+        help="Name or path of the tokenizer."
     )
     parser.add_argument(
         "--best-of",
diff --git a/benchmark/latency_throughput/test_latency.py b/benchmark/latency_throughput/test_latency.py
index f65f390e9..37ab6aef6 100644
--- a/benchmark/latency_throughput/test_latency.py
+++ b/benchmark/latency_throughput/test_latency.py
@@ -18,20 +18,22 @@ if __name__ == "__main__":
         args.port = 21000
     elif args.backend == "lightllm":
         args.port = 22000
+    elif args.backend == "xinfer":
+        args.port = 9988
     else:
         raise ValueError(f"Invalid backend: {args.backend}")
 
     url = f"{args.host}:{args.port}"
     a = random.randint(0, 1 << 20)
     max_new_tokens = 256
+    prompt = f"{a}, "
 
     tic = time.time()
     if args.backend == "srt":
         response = requests.post(
             url + "/generate",
             json={
-                "text": f"The capital of France is",
-                # "input_ids": [[2] * 256] * 196,
+                "text": prompt,
                 "sampling_params": {
                     "temperature": 0,
                     "max_new_tokens": max_new_tokens,
@@ -42,7 +44,7 @@
         response = requests.post(
             url + "/generate",
             json={
-                "inputs": f"{a}, ",
+                "inputs": prompt,
                 "parameters": {
                     "temperature": 0,
                     "max_new_tokens": max_new_tokens,
@@ -53,14 +55,36 @@
         response = requests.post(
             url + "/generate",
             json={
-                "prompt": f"{a}, ",
+                "prompt": prompt,
                 "temperature": 0,
                 "max_tokens": max_new_tokens,
             },
         )
+    elif args.backend == "xinfer":
+        import grpc
+        from xlm.proto import sampler_pb2, sampler_pb2_grpc
+
+        sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
+        sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
+
+        tic = time.time()
+        sample_request = sampler_pb2.SampleTextRequest(
+            prompt=prompt,
+            settings=sampler_pb2.SampleSettings(
+                max_len=max_new_tokens,
+                rng_seed=0,
+                temperature=0,
+                nucleus_p=1,
+            ),
+        )
+        stream = sampler.SampleText(sample_request)
+        response = "".join([x.text for x in stream])
 
     latency = time.time() - tic
-    ret = response.json()
+    if isinstance(response, str):
+        ret = response
+    else:
+        ret = response.json()
     print(ret)
 
     speed = max_new_tokens / latency
diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py
index b34168462..f5c1654aa 100644
--- a/python/sglang/srt/hf_transformers_utils.py
+++ b/python/sglang/srt/hf_transformers_utils.py
@@ -183,13 +183,13 @@ class TiktokenTokenizer:
         self.eos_token_id = tokenizer.eos_token
         self.vocab_size = tokenizer.n_vocab
 
-    def encode(self, x):
+    def encode(self, x, add_special_tokens=False):
         return self.tokenizer.encode(x)
 
     def decode(self, x):
         return self.tokenizer.decode(x)
 
-    def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
+    def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
         return self.tokenizer.decode_batch(batch)
 
     def convert_ids_to_tokens(self, index):
diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index 20cc662a0..fb4afa332 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -66,6 +66,7 @@ class Req:
         self.finish_reason = None
         self.hit_stop_str = None
 
+        # Prefix info
         self.extend_input_len = 0
         self.prefix_indices = []
         self.last_node = None
@@ -76,8 +77,8 @@
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
         self.prefill_token_logprobs = None
-        self.decode_token_logprobs = []
         self.prefill_top_logprobs = None
+        self.decode_token_logprobs = []
         self.decode_top_logprobs = []
         # The tokens is prefilled but need to be considered as decode tokens
         # and should be updated for the decode
logprobs diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index d52b3767d..4a4093525 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -91,26 +91,27 @@ class ModelRpcServer: tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, ) - self.max_total_num_token = self.model_runner.max_total_num_token - self.max_num_running_seq = self.max_total_num_token // 2 - self.max_prefill_num_token = max( + self.max_total_num_tokens = self.model_runner.max_total_num_tokens + self.max_prefill_tokens = max( self.model_config.context_len, ( - self.max_total_num_token // 6 - if server_args.max_prefill_num_token is None - else server_args.max_prefill_num_token + self.max_total_num_tokens // 6 + if server_args.max_prefill_tokens is None + else server_args.max_prefill_tokens ), ) + self.max_running_requests = (self.max_total_num_tokens // 2 + if server_args.max_running_requests is None else server_args.max_running_requests) + self.int_token_logit_bias = torch.tensor( get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) ) set_random_seed(server_args.random_seed) # Print info - logger.info( - f"[rank={self.tp_rank}] " - f"max_total_num_token={self.max_total_num_token}, " - f"max_prefill_num_token={self.max_prefill_num_token}, " + logger.info(f"[rank={self.tp_rank}] " + f"max_total_num_tokens={self.max_total_num_tokens}, " + f"max_prefill_tokens={self.max_prefill_tokens}, " f"context_len={self.model_config.context_len}, " ) if self.tp_rank == 0: @@ -125,9 +126,9 @@ class ModelRpcServer: self.tree_cache_metrics = {"total": 0, "hit": 0} self.scheduler = Scheduler( self.schedule_heuristic, - self.max_num_running_seq, - self.max_prefill_num_token, - self.max_total_num_token, + self.max_running_requests, + self.max_prefill_tokens, + self.max_total_num_tokens, self.tree_cache, ) self.req_to_token_pool = self.model_runner.req_to_token_pool @@ -219,7 +220,7 @@ class ModelRpcServer: # Print stats if self.tp_rank == 0: if self.decode_forward_ct % 40 == 0: - num_used = self.max_total_num_token - ( + num_used = self.max_total_num_tokens - ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) @@ -231,7 +232,7 @@ class ModelRpcServer: logger.info( f"#running-req: {len(self.running_batch.reqs)}, " f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_token:.2f}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"gen throughput (token/s): {throuhgput:.2f}, " f"#queue-req: {len(self.forward_queue)}" ) @@ -248,10 +249,10 @@ class ModelRpcServer: self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) - if available_size != self.max_total_num_token: + if available_size != self.max_total_num_tokens: warnings.warn( "Warning: " - f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n" + f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n" "KV cache pool leak detected!" 
) @@ -297,14 +298,14 @@ class ModelRpcServer: req.sampling_params.max_new_tokens = min( req.sampling_params.max_new_tokens, self.model_config.context_len - 1 - len(req.origin_input_ids), - self.max_total_num_token - 128 - len(req.origin_input_ids), + self.max_total_num_tokens - 128 - len(req.origin_input_ids), ) self.forward_queue.append(req) def get_new_fill_batch(self): if ( self.running_batch is not None - and len(self.running_batch.reqs) > self.max_num_running_seq + and len(self.running_batch.reqs) > self.max_running_requests ): return None @@ -360,7 +361,7 @@ class ModelRpcServer: req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens < available_size and req.extend_input_len + new_batch_input_tokens - < self.max_prefill_num_token + < self.max_prefill_tokens ): delta = self.tree_cache.inc_lock_ref(req.last_node) available_size += delta diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index faa7ab927..fa1190b04 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -301,19 +301,19 @@ class ModelRunner: return max_num_token def init_memory_pool(self, total_gpu_memory): - self.max_total_num_token = self.profile_max_num_token(total_gpu_memory) + self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) - if self.max_total_num_token <= 0: + if self.max_total_num_tokens <= 0: raise RuntimeError( "Not enought memory. " "Please try to increase --mem-fraction-static." ) self.req_to_token_pool = ReqToTokenPool( - int(self.max_total_num_token / self.model_config.context_len * 256), + int(self.max_total_num_tokens / self.model_config.context_len * 256), self.model_config.context_len + 8, ) self.token_to_kv_pool = TokenToKVPool( - self.max_total_num_token, + self.max_total_num_tokens, dtype=torch.float16, head_num=self.model_config.num_key_value_heads // self.tp_size, head_dim=self.model_config.head_dim, diff --git a/python/sglang/srt/managers/router/scheduler.py b/python/sglang/srt/managers/router/scheduler.py index 806151931..def11e775 100644 --- a/python/sglang/srt/managers/router/scheduler.py +++ b/python/sglang/srt/managers/router/scheduler.py @@ -6,15 +6,15 @@ class Scheduler: def __init__( self, schedule_heuristic, - max_running_seq, - max_prefill_num_token, - max_total_num_token, + max_running_seqs, + max_prefill_num_tokens, + max_total_num_tokens, tree_cache, ): self.schedule_heuristic = schedule_heuristic - self.max_running_seq = max_running_seq - self.max_prefill_num_token = max_prefill_num_token - self.max_total_num_token = max_total_num_token + self.max_running_seqs = max_running_seqs + self.max_prefill_num_tokens = max_prefill_num_tokens + self.max_total_num_tokens = max_total_num_tokens self.tree_cache = tree_cache def get_priority_queue(self, forward_queue): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 061340ffa..416f0cc6f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -24,7 +24,8 @@ class ServerArgs: # Memory and scheduling mem_fraction_static: Optional[float] = None - max_prefill_num_token: Optional[int] = None + max_prefill_tokens: Optional[int] = None + max_running_requests: Optional[int] = None schedule_heuristic: str = "lpm" schedule_conservativeness: float = 1.0 @@ -149,11 +150,17 @@ class ServerArgs: help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). 
Use a smaller value if you see out-of-memory errors.", ) parser.add_argument( - "--max-prefill-num-token", + "--max-prefill-tokens", type=int, - default=ServerArgs.max_prefill_num_token, + default=ServerArgs.max_prefill_tokens, help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.", ) + parser.add_argument( + "--max-running-requests", + type=int, + default=ServerArgs.max_running_requests, + help="The maximum number of running requests.", + ) parser.add_argument( "--schedule-heuristic", type=str, diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index b2aaeafaa..d73c59219 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -88,6 +88,28 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None): return pred +def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None): + import grpc + from xlm.proto import sampler_pb2, sampler_pb2_grpc + + sampler_channel = grpc.insecure_channel(url.replace("http://", "")) + sampler = sampler_pb2_grpc.SamplerStub(sampler_channel) + + sample_request = sampler_pb2.SampleTextRequest( + prompt=prompt, + settings=sampler_pb2.SampleSettings( + max_len=max_tokens, + rng_seed=0, + temperature=max(temperature, 1e-7), + nucleus_p=1, + stop_strings=[stop], + ), + ) + stream = sampler.SampleText(sample_request) + response = "".join([x.text for x in stream]) + return response + + def call_generate_guidance( prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None ): @@ -228,6 +250,7 @@ def add_common_other_args_and_parse(parser): "vllm", "outlines", "lightllm", + "xinfer", "guidance", "lmql", "srt-raw", @@ -248,6 +271,7 @@ def add_common_other_args_and_parse(parser): "lightllm": 22000, "lmql": 23000, "srt-raw": 30000, + "xinfer": 9988, } args.port = default_port.get(args.backend, None) return args @@ -283,6 +307,8 @@ def _get_call_generate(args): return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate") elif args.backend == "srt-raw": return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate") + elif args.backend == "xinfer": + return partial(call_generate_xinfer, url=f"{args.host}:{args.port}") elif args.backend == "outlines": return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate") elif args.backend == "guidance":
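
Note on the dataset-free path added to bench_throughput.py: when --dataset is omitted, requests are synthesized by sampling per-request input/output lengths and decoding pseudo-random token ids into prompts. The snippet below is a minimal standalone sketch of that logic, not part of the patch itself; the tokenizer name simply mirrors the new --tokenizer default, and the variable names are illustrative.

import numpy as np
from transformers import AutoTokenizer

# Mirrors the new CLI defaults: --input-len 1024, --output-len 128, --range-ratio 1.0
num_prompts = 4
input_len, output_len, range_ratio = 1024, 128, 1.0

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

# Sample one input length and one output length per request from [range_ratio * len, len].
input_lens = np.random.randint(int(input_len * range_ratio), input_len + 1, size=num_prompts)
output_lens = np.random.randint(int(output_len * range_ratio), output_len + 1, size=num_prompts)
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)

# Decode pseudo-random token ids into text prompts, one per request.
input_requests = []
for i in range(num_prompts):
    token_ids = [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(int(input_lens[i]))]
    input_requests.append((tokenizer.decode(token_ids), int(input_lens[i]), int(output_lens[i])))

Each tuple has the same (prompt, prompt_len, output_len) shape that sample_requests returns when --dataset is given, so the rest of the benchmark loop is unchanged.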