From 665815969a71a478b840999cb821054814a723fc Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 13 Jul 2024 05:29:46 -0700
Subject: [PATCH] Enable cuda graph by default (#612)

---
 benchmark/latency_throughput/bench_one.py     |  82 ++++++---
 python/sglang/bench_latency.py                |   1 -
 python/sglang/global_config.py                |  40 ++--
 .../managers/controller/cuda_graph_runner.py  | 173 ++++++++++++++++++
 .../srt/managers/controller/infer_batch.py    |  19 +-
 .../srt/managers/controller/model_runner.py   |  67 ++++---
 .../srt/managers/controller/tp_worker.py      |  14 +-
 python/sglang/srt/memory_pool.py              |  10 +-
 python/sglang/srt/server.py                   |   1 +
 python/sglang/srt/server_args.py              |   8 +-
 10 files changed, 331 insertions(+), 84 deletions(-)
 create mode 100644 python/sglang/srt/managers/controller/cuda_graph_runner.py

diff --git a/benchmark/latency_throughput/bench_one.py b/benchmark/latency_throughput/bench_one.py
index 5b2fa3cbc..cb3ec5a4e 100644
--- a/benchmark/latency_throughput/bench_one.py
+++ b/benchmark/latency_throughput/bench_one.py
@@ -1,45 +1,43 @@
+"""
+Usage:
+python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
+"""
+
 import argparse
+import json
 import time
 
+import numpy as np
 import requests
 
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=None)
-    parser.add_argument("--backend", type=str, default="srt")
-    parser.add_argument("--batch-size", type=int, default=1)
-    parser.add_argument("--max-tokens", type=int, default=256)
-    args = parser.parse_args()
-
-    if args.port is None:
-        if args.backend == "srt":
-            args.port = 30000
-        elif args.backend == "vllm":
-            args.port = 21000
-        elif args.backend == "lightllm":
-            args.port = 22000
-        elif args.backend == "ginfer":
-            args.port = 9988
-        else:
-            raise ValueError(f"Invalid backend: {args.backend}")
 
+def run_one_batch_size(bs):
     url = f"{args.host}:{args.port}"
-    a = 20
     max_new_tokens = args.max_tokens
+
+    a = 20
     prompt = f"{a, }"
 
     tic = time.time()
     if args.backend == "srt":
+        if args.input_len:
+            inputs = {"input_ids": [
+                [int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
+            ]}
+        else:
+            inputs = {"text": [
+                f"{i, }" for i in range(bs)
+            ]}
+
         response = requests.post(
             url + "/generate",
             json={
-                "text": [prompt] * args.batch_size,
                 "sampling_params": {
                     "temperature": 0,
                     "max_new_tokens": max_new_tokens,
                     "ignore_eos": True,
                 },
+                **inputs,
             },
         )
     elif args.backend == "lightllm":
@@ -91,5 +89,41 @@ if __name__ == "__main__":
     ret = response.json()
     print(ret)
 
-    speed = args.batch_size * max_new_tokens / latency
-    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
+    output_throughput = bs * max_new_tokens / latency
+    print(f"latency: {latency:.2f} s, speed: {output_throughput:.2f} token/s")
+
+    with open("tmp_output.txt", "a") as fout:
+        res = {
+            "input_len": args.input_len,
+            "output_len": args.max_tokens,
+            "batch_size": bs,
+            "latency": latency,
+            "output_throughput": output_throughput
+        }
+        fout.write(json.dumps(res) + "\n")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=None)
+    parser.add_argument("--backend", type=str, default="srt")
+    parser.add_argument("--input-len", type=int, default=None)
+    parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
+    parser.add_argument("--max-tokens", type=int, default=256)
+    args = parser.parse_args()
+
+    if args.port is None:
+        if args.backend == "srt":
+            args.port = 30000
+        elif args.backend == "vllm":
+            args.port = 21000
+        elif args.backend == "lightllm":
+            args.port = 22000
+        elif args.backend == "ginfer":
+            args.port = 9988
+        else:
+            raise ValueError(f"Invalid backend: {args.backend}")
+
+    for bs in args.batch_size:
+        run_one_batch_size(bs)
diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 49727b121..23ec11a34 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -30,7 +30,6 @@ import argparse
 import dataclasses
 import logging
 import multiprocessing
-import os
 import time
 
diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index 00340b59a..662cb4a6f 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -8,36 +8,40 @@ class GlobalConfig:
         # 2: output final text after every run
         self.verbosity = 0
 
+        # Default backend of the language
         self.default_backend = None
 
-        # Output configs
-        self.skip_special_tokens_in_output = True
-        self.spaces_between_special_tokens_in_out = True
-
-        # Optimization configs
-        self.eager_fill_image = False
-        self.enable_precache_with_tracing = True
-        self.enable_parallel_encoding = True
-        self.enable_parallel_decoding = True
-
-        # Choices: ["no_adjust", "adjust_cache"]
-        # no_adjust: Do not adjust the position embedding of KV cache.
-        # adjust_cache: Adjust the position embedding of KV cache.
-        self.concate_and_append_mode = "no_adjust"
-
-        # Request dependency time due to network delay
+        # Runtime constants: Request dependency time due to network delay
         self.request_dependency_delay = 0.02
         self.wait_for_new_request_delay = 0.0006
 
-        # New generation token ratio estimation
+        # Runtime constants: New generation token ratio estimation
        self.base_new_token_ratio = 0.4
         self.base_min_new_token_ratio = 0.2
         self.new_token_ratio_decay = 0.0001
         self.new_token_ratio_recovery = 0.05
 
-        # The threshold (number of tokens) to trigger layer-wise cuda sync.
+        # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
         # This can improve the speed for large batch sizes during prefill.
         self.layer_sync_threshold = 8192
 
+        # Runtime constants: Flashinfer
+        self.flashinfer_workspace_size = 192 * 1024 * 1024
+
+        # Output tokenization configs
+        self.skip_special_tokens_in_output = True
+        self.spaces_between_special_tokens_in_out = True
+
+        # Interpreter optimization configs
+        self.eager_fill_image = False
+        self.enable_precache_with_tracing = True
+        self.enable_parallel_encoding = True
+        self.enable_parallel_decoding = True
+
+        # Deprecated
+        # Choices: ["no_adjust", "adjust_cache"]
+        # no_adjust: Do not adjust the position embedding of KV cache.
+        # adjust_cache: Adjust the position embedding of KV cache.
+        self.concate_and_append_mode = "no_adjust"
 
 
 global_config = GlobalConfig()
diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py
new file mode 100644
index 000000000..eee1bb81f
--- /dev/null
+++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -0,0 +1,173 @@
+"""Run the model with cuda graph."""
+
+import bisect
+
+import torch
+from vllm.distributed.parallel_state import graph_capture
+
+from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.infer_batch import (
+    Batch, ForwardMode, InputMetadata, init_flashinfer_args
+)
+
+
+class CudaGraphRunner:
+    def __init__(self, model_runner, max_batch_size_to_capture):
+        self.model_runner = model_runner
+        self.graphs = {}
+        self.input_buffers = {}
+        self.output_buffers = {}
+        self.flashinfer_handlers = {}
+        self.graph_memory_pool = None
+
+        # Common inputs
+        self.max_bs = max_batch_size_to_capture
+        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+
+        # FlashInfer inputs
+        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
+        self.flashinfer_kv_indptr = torch.zeros(
+            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_indices = torch.zeros(
+            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
+        )
+        self.flashinfer_kv_last_page_len = torch.ones(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+
+    def can_run(self, batch_size):
+        return batch_size < self.max_bs
+
+    def capture(self, batch_size_list):
+        self.batch_size_list = batch_size_list
+        with graph_capture() as graph_capture_context:
+            self.stream = graph_capture_context.stream
+            for bs in batch_size_list:
+                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
+                self.graphs[bs] = graph
+                self.input_buffers[bs] = input_buffers
+                self.output_buffers[bs] = output_buffers
+                self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs):
+        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        graph = torch.cuda.CUDAGraph()
+        stream = self.stream
+
+        # Common inputs
+        input_ids = self.input_ids[:bs]
+        req_pool_indices = self.req_pool_indices[:bs]
+        seq_lens = self.seq_lens[:bs]
+        position_ids_offsets = self.position_ids_offsets[:bs]
+        out_cache_loc = self.out_cache_loc[:bs]
+
+        # FlashInfer inputs
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
+            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
+        ):
+            use_tensor_cores = True
+        else:
+            use_tensor_cores = False
+        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffer, "NHD",
+            use_cuda_graph=True,
+            use_tensor_cores=use_tensor_cores,
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
+            paged_kv_indices_buffer=self.flashinfer_kv_indices,
+            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
+        )
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            req_pool_indices,
+            seq_lens,
+            None,
+            flashinfer_decode_wrapper,
+        )
+
+        # Run and capture
+        def run_once():
+            input_metadata = InputMetadata.create(
+                self.model_runner,
+                forward_mode=ForwardMode.DECODE,
+                req_pool_indices=req_pool_indices,
+                seq_lens=seq_lens,
+                prefix_lens=None,
+                position_ids_offsets=position_ids_offsets,
+                out_cache_loc=out_cache_loc,
+                out_cache_cont_start=None,
+                out_cache_cont_end=None,
+                return_logprob=False,
+                top_logprobs_nums=0,
+                skip_flashinfer_init=True,
+            )
+            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
+            return self.model_runner.model.forward(
+                input_ids, input_metadata.positions, input_metadata
+            )
+
+        for _ in range(2):
+            run_once()
+
+        torch.cuda.synchronize()
+        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
+            out = run_once()
+        torch.cuda.synchronize()
+        self.graph_memory_pool = graph.pool()
+        return graph, None, out, flashinfer_decode_wrapper
+
+    def replay(self, batch: Batch):
+        assert batch.out_cache_loc is not None
+        assert not batch.return_logprob
+        raw_bs = len(batch.reqs)
+
+        # Pad
+        index = bisect.bisect_left(self.batch_size_list, raw_bs)
+        bs = self.batch_size_list[index]
+        if bs != raw_bs:
+            self.seq_lens.fill_(1)
+            self.out_cache_loc.zero_()
+
+        # Common inputs
+        self.input_ids[:raw_bs] = batch.input_ids
+        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
+        self.seq_lens[:raw_bs] = batch.seq_lens
+        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
+        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+
+        # FlashInfer inputs
+        init_flashinfer_args(
+            ForwardMode.DECODE,
+            self.model_runner,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
+            None,
+            self.flashinfer_handlers[bs],
+        )
+
+        # Replay
+        self.graphs[bs].replay()
+        output = self.output_buffers[bs]
+
+        # Unpad
+        if bs == raw_bs:
+            return output
+        else:
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
+                next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
+            )
+        return output
\ No newline at end of file
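Editor's note (not part of the patch): the new file above follows the standard PyTorch CUDA-graph recipe — warm up first, record one graph per bucketed batch size into static input/output buffers that share a memory pool, then at decode time copy the live batch into those buffers, replay the closest graph, and slice the padded rows off the result. With the capture list introduced later in this patch ([1, 2, 4] plus multiples of 8 up to 120) and the strict comparison in can_run, decode batches smaller than the largest captured size are replayed while larger ones fall back to the eager path. Below is a minimal, self-contained sketch of the same pattern; the toy linear model, buffer shapes, and bucket sizes are illustrative assumptions, not sglang APIs.

import torch

# Toy stand-in for one decode forward pass; the real runner captures
# model_runner.model.forward with InputMetadata instead.
model = torch.nn.Linear(16, 16).cuda().eval()

max_bs = 8
static_in = torch.zeros((max_bs, 16), device="cuda")
graphs, outputs, pool = {}, {}, None

# Warm up before capturing (the runner above calls run_once() twice for the
# same reason); PyTorch docs recommend doing the warmup on a side stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(2):
        model(static_in)
torch.cuda.current_stream().wait_stream(s)

# One graph per bucketed batch size, all sharing a memory pool
# (the analogue of graph_memory_pool above).
for bs in (1, 2, 4, 8):
    g = torch.cuda.CUDAGraph()
    with torch.no_grad(), torch.cuda.graph(g, pool=pool):
        outputs[bs] = model(static_in[:bs])
    pool = g.pool()
    graphs[bs] = g

def decode(x):
    raw_bs = x.shape[0]
    bs = min(b for b in graphs if b >= raw_bs)  # pad up to a captured bucket
    static_in[:raw_bs].copy_(x)
    static_in[raw_bs:bs].zero_()                # dummy rows, discarded below
    graphs[bs].replay()
    return outputs[bs][:raw_bs]                 # unpad

print(decode(torch.randn(3, 16, device="cuda")).shape)  # torch.Size([3, 16])

Sharing one memory pool across captures, as graph_memory_pool does in the runner, keeps the per-graph memory overhead small when many batch sizes are recorded.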
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index b7e05a5dc..d1bc60f9d 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -675,7 +675,11 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
         probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        try:
+            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        except RuntimeError as e:
+            warnings.warn(f"Ignore errors in sampling: {e}")
+            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
             -1
         )
@@ -757,9 +761,11 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        skip_flashinfer_init=False,
     ):
-        if not model_runner.server_args.disable_flashinfer:
-            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                                 model_runner.flashinfer_decode_wrapper)
 
         batch_size = len(req_pool_indices)
 
@@ -826,7 +832,8 @@ class InputMetadata:
         return ret
 
 
-def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens):
+def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
+                         flashinfer_decode_wrapper):
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
@@ -857,8 +864,8 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         )
 
     if forward_mode == ForwardMode.DECODE:
-        model_runner.flashinfer_decode_wrapper.end_forward()
-        model_runner.flashinfer_decode_wrapper.begin_forward(
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
             kv_indptr,
             kv_indices,
             kv_last_page_len,
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 5945282ab..d7b66e76b 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -15,6 +15,7 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
+from sglang.global_config import global_config
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
@@ -90,6 +91,9 @@ class ModelRunner:
         self.init_cublas()
         self.init_flash_infer()
 
+        # Capture cuda graphs
+        self.init_cuda_graphs()
+
     def load_model(self):
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight begin. "
@@ -203,17 +207,51 @@ class ModelRunner:
         else:
             use_tensor_cores = False
 
-        workspace_buffers = torch.empty(
-            3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+        self.flashinfer_workspace_buffers = torch.empty(
+            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
         )
         self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-            workspace_buffers[0], "NHD"
+            self.flashinfer_workspace_buffers[0], "NHD"
         )
         self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-            workspace_buffers[1], "NHD"
+            self.flashinfer_workspace_buffers[1], "NHD"
         )
         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+        )
+
+    def init_cuda_graphs(self):
+        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+            self.cuda_graph_runner = None
+            return
+
+        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+        self.cuda_graph_runner.capture(batch_size_list)
+
+    @torch.inference_mode()
+    def forward_decode(self, batch: Batch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+            return self.cuda_graph_runner.replay(batch)
+
+        input_metadata = InputMetadata.create(
+            self,
+            forward_mode=ForwardMode.DECODE,
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            prefix_lens=batch.prefix_lens,
+            position_ids_offsets=batch.position_ids_offsets,
+            out_cache_loc=batch.out_cache_loc,
+            out_cache_cont_start=batch.out_cache_cont_start,
+            out_cache_cont_end=batch.out_cache_cont_end,
+            top_logprobs_nums=batch.top_logprobs_nums,
+            return_logprob=batch.return_logprob,
+        )
+        return self.model.forward(
+            batch.input_ids, input_metadata.positions, input_metadata
         )
 
     @torch.inference_mode()
@@ -233,25 +271,6 @@ class ModelRunner:
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
-    @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-        )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
-
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: Batch):
         input_metadata = InputMetadata.create(
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 6d92c6bab..89cd851a0 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -98,7 +98,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-            4096
+            8192
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
@@ -314,11 +314,9 @@ class ModelTpServer:
             self.forward_queue.append(req)
 
     def get_new_fill_batch(self) -> Optional[Batch]:
-        if (
-            self.running_batch is not None
-            and len(self.running_batch.reqs) > self.max_running_requests
-        ):
-            return None
+        running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+        if running_bs > self.max_running_requests:
+            return
 
         # Compute matched prefix length
         for req in self.forward_queue:
@@ -394,6 +392,10 @@ class ModelTpServer:
                     new_batch_input_tokens += req.extend_input_len
                 else:
                     break
+
+            if running_bs + len(can_run_list) > self.max_running_requests:
+                break
+
         if len(can_run_list) == 0:
             return None
 
diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py
index 33f4b8784..51b9beeb2 100644
--- a/python/sglang/srt/memory_pool.py
+++ b/python/sglang/srt/memory_pool.py
@@ -38,7 +38,10 @@ class ReqToTokenPool:
 
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
-        self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
+        self.size = size
+        # mem_state is the reference counter.
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
         self.total_ref_ct = 0
 
         # [size, key/value, head_num, head_dim] for each layer
@@ -47,6 +50,8 @@ class TokenToKVPool:
             for _ in range(layer_num)
         ]
 
+        self.clear()
+
     def get_key_buffer(self, layer_id):
         return self.kv_data[layer_id][:, 0]
 
@@ -101,3 +106,6 @@ class TokenToKVPool:
     def clear(self):
         self.mem_state.fill_(0)
         self.total_ref_ct = 0
+
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.add_refs(torch.tensor([0], dtype=torch.int32))
\ No newline at end of file
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index eb37c7bb5..6cda67dea 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -146,6 +146,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index ef8b6d252..46dfc25d2 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -29,7 +29,7 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     schedule_heuristic: str = "lpm"
-    schedule_conservativeness: float = 1.0
+    schedule_conservativeness: float = 0.8
 
     # Other runtime options
     tp_size: int = 1
@@ -68,13 +68,13 @@ class ServerArgs:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 8:
-                self.mem_fraction_static = 0.80
+                self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.82
+                self.mem_fraction_static = 0.80
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else:
-                self.mem_fraction_static = 0.90
+                self.mem_fraction_static = 0.88
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
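Editor's note (not part of the patch): the extra slot added to TokenToKVPool exists because padded rows in a replayed graph still run the decode kernels and must write a dummy KV entry somewhere. clear() pins a permanent reference on index 0 so the allocator never hands that slot out, and CudaGraphRunner.replay zeroes out_cache_loc for the padded tail, so those dummy writes all land in the reserved slot. A small schematic of that bookkeeping with simplified names (illustrative only, not sglang code):

import torch

POOL_SIZE = 1024
# One extra slot beyond the usable pool; reference-counted like mem_state.
mem_state = torch.zeros(POOL_SIZE + 1, dtype=torch.int16)
mem_state[0] += 1  # clear(): slot 0 is never allocatable

out_cache_loc = torch.zeros(8, dtype=torch.int32)  # replay(): padded rows point at slot 0
out_cache_loc[:3] = torch.tensor([17, 42, 99], dtype=torch.int32)  # real rows keep their real slots
# During graph replay, rows 3..7 harmlessly overwrite the KV entry at slot 0,
# while rows 0..2 write to their allocated locations as usual.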