From 7cd4f244a42178d0cdfb6a81156f38e87a7d92cd Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 29 Jul 2024 03:32:58 -0700 Subject: [PATCH] Chunked prefill (#800) --- .../managers/controller/schedule_heuristic.py | 20 +-- .../srt/managers/controller/tp_worker.py | 118 ++++++++++++++---- python/sglang/srt/memory_pool.py | 2 +- python/sglang/srt/server.py | 53 +++++--- python/sglang/srt/server_args.py | 21 +++- 5 files changed, 160 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/managers/controller/schedule_heuristic.py b/python/sglang/srt/managers/controller/schedule_heuristic.py index 88620bf99..d1f45836b 100644 --- a/python/sglang/srt/managers/controller/schedule_heuristic.py +++ b/python/sglang/srt/managers/controller/schedule_heuristic.py @@ -38,24 +38,24 @@ class ScheduleHeuristic: self.max_total_num_tokens = max_total_num_tokens self.tree_cache = tree_cache - def get_priority_queue(self, forward_queue): + def get_priority_queue(self, waiting_queue): if self.schedule_heuristic == "lpm": # longest prefix match - forward_queue.sort(key=lambda x: -len(x.prefix_indices)) - return forward_queue + waiting_queue.sort(key=lambda x: -len(x.prefix_indices)) + return waiting_queue elif self.schedule_heuristic == "fcfs": # first come first serve - return forward_queue + return waiting_queue elif self.schedule_heuristic == "lof": # longest output first - forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) - return forward_queue + waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) + return waiting_queue elif self.schedule_heuristic == "random": - random.shuffle(forward_queue) - return forward_queue + random.shuffle(waiting_queue) + return waiting_queue elif self.schedule_heuristic == "dfs-weight": last_node_to_reqs = defaultdict(list) - for req in forward_queue: + for req in waiting_queue: last_node_to_reqs[req.last_node].append(req) node_to_weight = defaultdict(int) @@ -67,7 +67,7 @@ class ScheduleHeuristic: self.get_dfs_priority( self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q ) - assert len(q) == len(forward_queue) + assert len(q) == len(waiting_queue) return q else: raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}") diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 94e535d14..abd933075 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -77,6 +77,10 @@ class ModelTpServer: self.schedule_heuristic = server_args.schedule_heuristic self.disable_regex_jump_forward = server_args.disable_regex_jump_forward + # Chunked prefill + self.chunked_prefill_size = server_args.chunked_prefill_size + self.current_inflight_req = None + # Init model and tokenizer self.model_config = ModelConfig( server_args.model_path, @@ -157,7 +161,7 @@ class ModelTpServer: self.token_to_kv_pool = self.model_runner.token_to_kv_pool # Init running status - self.forward_queue: List[Req] = [] + self.waiting_queue: List[Req] = [] self.running_batch: Batch = None self.out_pyobjs = [] self.decode_forward_ct = 0 @@ -220,6 +224,7 @@ class ModelTpServer: # Run a new prefill batch self.forward_prefill_batch(new_batch) self.cache_filled_batch(new_batch) + self.filter_out_inflight(new_batch) if not new_batch.is_empty(): if self.running_batch is None: @@ -261,7 +266,7 @@ class ModelTpServer: f"#token: {num_used}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"gen throughput (token/s): 
{throughput:.2f}, " - f"#queue-req: {len(self.forward_queue)}" + f"#queue-req: {len(self.waiting_queue)}" ) def check_memory(self): @@ -328,9 +333,10 @@ class ModelTpServer: ), self.max_req_input_len - 1 - len(req.origin_input_ids), ) - self.forward_queue.append(req) + self.waiting_queue.append(req) def get_new_prefill_batch(self) -> Optional[Batch]: + # TODO(lsyin): organize this function running_bs = ( len(self.running_batch.reqs) if self.running_batch is not None else 0 ) @@ -338,7 +344,7 @@ class ModelTpServer: return # Compute matched prefix length - for req in self.forward_queue: + for req in self.waiting_queue: req.input_ids = req.origin_input_ids + req.output_ids prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids) if req.return_logprob: @@ -348,7 +354,7 @@ class ModelTpServer: req.last_node = last_node # Get priority queue - self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue) + self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue) # Add requests if there is available space can_run_list = [] @@ -367,7 +373,33 @@ class ModelTpServer: ] ) - for req in self.forward_queue: + # Handle the current inflight request + take_inflight = 0 + if self.current_inflight_req: + take_inflight = 1 + r = self.current_inflight_req + r.input_ids = r.origin_input_ids + r.output_ids + truncated = ( + len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size + ) + r.extend_input_len = min( + len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size + ) + r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len] + can_run_list.append(r) + + if not truncated: + # Finish inflight + self.current_inflight_req = None + new_batch_total_tokens += ( + r.extend_input_len + r.sampling_params.max_new_tokens + ) + new_batch_input_tokens += r.extend_input_len + else: + new_batch_total_tokens += r.extend_input_len + new_batch_input_tokens += r.extend_input_len + + for req in self.waiting_queue: if req.return_logprob and req.normalized_prompt_logprob is None: # Need at least two tokens to compute normalized logprob if req.extend_input_len < 2: @@ -409,11 +441,36 @@ class ModelTpServer: break else: # Add this request to the running batch - can_run_list.append(req) - new_batch_total_tokens += ( - req.extend_input_len + req.sampling_params.max_new_tokens - ) - new_batch_input_tokens += req.extend_input_len + if ( + new_batch_input_tokens + req.extend_input_len + <= self.chunked_prefill_size + or ( + req.return_logprob and req.normalized_prompt_logprob is None + ) + ): + can_run_list.append(req) + new_batch_total_tokens += ( + req.extend_input_len + req.sampling_params.max_new_tokens + ) + new_batch_input_tokens += req.extend_input_len + else: + trunc_len = self.chunked_prefill_size - new_batch_input_tokens + + if trunc_len <= 0: + # Undo locking + delta = self.tree_cache.dec_lock_ref(req.last_node) + available_size += delta + break + + req.extend_input_len = trunc_len + req.input_ids = req.input_ids[ + : len(req.prefix_indices) + req.extend_input_len + ] + can_run_list.append(req) + self.current_inflight_req = req + new_batch_input_tokens += req.extend_input_len + new_batch_total_tokens += req.extend_input_len + break else: break @@ -440,7 +497,7 @@ class ModelTpServer: f"#cached-token: {hit_tokens}, " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"#running-req: {running_bs}, " - f"#queue-req: {len(self.forward_queue) - len(can_run_list)}" + f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}" 
) # Return the new batch @@ -450,7 +507,7 @@ class ModelTpServer: self.token_to_kv_pool, self.tree_cache, ) - self.forward_queue = [x for x in self.forward_queue if x not in can_run_list] + self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list] return new_batch def forward_prefill_batch(self, batch: Batch): @@ -482,9 +539,10 @@ class ModelTpServer: # Check finish conditions pt = 0 for i, req in enumerate(batch.reqs): - req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_ids[i]) - req.check_finished() + if req is not self.current_inflight_req: + req.completion_tokens_wo_jump_forward += 1 + req.output_ids.append(next_token_ids[i]) + req.check_finished() if req.return_logprob: self.add_logprob_return_values(i, req, pt, next_token_ids, output) @@ -545,7 +603,7 @@ class ModelTpServer: req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() for i, req in enumerate(batch.reqs): new_prefix_indices, new_last_node = self.tree_cache.cache_req( - token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1], + token_ids=tuple(req.input_ids), last_uncached_pos=len(req.prefix_indices), req_pool_idx=req_pool_indices_cpu[i], del_in_memory_pool=False, @@ -553,6 +611,10 @@ class ModelTpServer: ) req.prefix_indices, req.last_node = new_prefix_indices, new_last_node + if req is self.current_inflight_req: + # inflight request would get a new req idx + self.req_to_token_pool.free(int(req_pool_indices_cpu[i])) + def forward_decode_batch(self, batch: Batch): # Check if decode out of memory if not batch.check_decode_mem(): @@ -566,7 +628,7 @@ class ModelTpServer: f"#retracted_reqs: {len(retracted_reqs)}, " f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" ) - self.forward_queue.extend(retracted_reqs) + self.waiting_queue.extend(retracted_reqs) else: self.new_token_ratio = max( self.new_token_ratio - self.new_token_ratio_decay, @@ -576,7 +638,7 @@ class ModelTpServer: if not self.disable_regex_jump_forward: # Check for jump-forward jump_forward_reqs = batch.check_for_jump_forward(self.model_runner) - self.forward_queue.extend(jump_forward_reqs) + self.waiting_queue.extend(jump_forward_reqs) if batch.is_empty(): return @@ -711,8 +773,18 @@ class ModelTpServer: else: batch.reqs = [] + def filter_out_inflight(self, batch: Batch): + # TODO(lsyin): reduce the overhead, make a special version for this + if self.current_inflight_req is None: + return + + to_remove = batch.reqs.index(self.current_inflight_req) + unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove] + + batch.filter_batch(unfinished_indices) + def flush_cache(self): - if len(self.forward_queue) == 0 and ( + if len(self.waiting_queue) == 0 and ( self.running_batch is None or len(self.running_batch.reqs) == 0 ): self.tree_cache.reset() @@ -725,20 +797,20 @@ class ModelTpServer: else: warnings.warn( f"Cache not flushed because there are pending requests. 
" - f"#queue-req: {len(self.forward_queue)}, " + f"#queue-req: {len(self.waiting_queue)}, " f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" ) def abort_request(self, recv_req): # Delete requests in the waiting queue to_del = None - for i, req in enumerate(self.forward_queue): + for i, req in enumerate(self.waiting_queue): if req.rid == recv_req.rid: to_del = i break if to_del is not None: - del self.forward_queue[to_del] + del self.waiting_queue[to_del] # Delete requests in the running batch if self.running_batch: diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py index d4948b525..fa38ee41c 100644 --- a/python/sglang/srt/memory_pool.py +++ b/python/sglang/srt/memory_pool.py @@ -45,7 +45,7 @@ class ReqToTokenPool: return select_index - def free(self, free_index: int): + def free(self, free_index): self.mem_state[free_index] = True if isinstance(free_index, (int,)): self.can_use_mem_size += 1 diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index e4801deeb..41d71545a 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -175,6 +175,39 @@ def _set_torch_compile_config(): torch._dynamo.config.accumulated_cache_size_limit = 256 +def set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # Set ulimit + set_ulimit() + + # Enable show time cost for debugging + if server_args.show_time_cost: + enable_show_time_cost() + + # Disable disk cache + if server_args.disable_disk_cache: + disable_cache() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. + maybe_set_triton_cache_manager() + + # Set torch compile config + if server_args.enable_torch_compile: + _set_torch_compile_config() + + # Set global chat template + if server_args.chat_template: + # TODO: replace this with huggingface transformers template + load_chat_template_for_openai_api(server_args.chat_template) + + def launch_server( server_args: ServerArgs, model_overide_args: Optional[dict] = None, @@ -190,16 +223,6 @@ def launch_server( format="%(message)s", ) - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - set_ulimit() - if server_args.show_time_cost: - enable_show_time_cost() - if server_args.disable_disk_cache: - disable_cache() if not server_args.disable_flashinfer: assert_pkg_version( "flashinfer", @@ -208,14 +231,8 @@ def launch_server( "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", ) - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. 
- maybe_set_triton_cache_manager() - if server_args.chat_template: - # TODO: replace this with huggingface transformers template - load_chat_template_for_openai_api(server_args.chat_template) - if server_args.enable_torch_compile: - _set_torch_compile_config() + + set_envs_and_config(server_args) # Allocate ports server_args.port, server_args.additional_ports = allocate_init_ports( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 5c22c48e6..c9535f402 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -65,6 +65,9 @@ class ServerArgs: dp_size: int = 1 load_balance_method: str = "round_robin" + # Chunked Prefill + chunked_prefill_size: Optional[int] = None + # Optimization/debug options disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False @@ -83,6 +86,8 @@ class ServerArgs: node_rank: Optional[int] = None def __post_init__(self): + if self.chunked_prefill_size is None: + self.chunked_prefill_size = 1 << 30 if self.tokenizer_path is None: self.tokenizer_path = self.model_path if self.mem_fraction_static is None: @@ -223,7 +228,7 @@ class ServerArgs: parser.add_argument( "--max-num-reqs", type=int, - default=None, + default=ServerArgs.max_num_reqs, help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.", ) parser.add_argument( @@ -311,10 +316,18 @@ class ServerArgs: help="The nccl init address of multi-node server.", ) parser.add_argument( - "--nnodes", type=int, default=1, help="The number of nodes." + "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes." ) parser.add_argument("--node-rank", type=int, help="The node rank.") + # Chunked prefill + parser.add_argument( + "--chunked-prefill-size", + type=int, + default=ServerArgs.chunked_prefill_size, + help="The size of the chunked prefill.", + ) + # Optimization/debug options parser.add_argument( "--disable-flashinfer", @@ -393,6 +406,10 @@ class ServerArgs: self.dp_size > 1 and self.node_rank is not None ), "multi-node data parallel is not supported" + assert not ( + self.chunked_prefill_size < (1 << 30) and self.disable_radix_cache + ), "chunked prefill is not supported with radix cache disabled currently" + @dataclasses.dataclass class PortArgs:
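
A minimal usage sketch, assuming the `ServerArgs` dataclass and the `launch_server` entry point that appear in the diff above; the model path and the chunk size of 4096 are placeholder assumptions, and only the `chunked_prefill_size` option itself is introduced by this patch:

    from sglang.srt.server import launch_server
    from sglang.srt.server_args import ServerArgs

    if __name__ == "__main__":
        # Cap each prefill batch at 4096 input tokens. A request whose uncached
        # input is longer than this is truncated to the chunk size and kept as
        # the current inflight request, so its remaining tokens are prefilled
        # in later batches.
        args = ServerArgs(
            model_path="<model-path>",  # placeholder
            chunked_prefill_size=4096,
        )
        launch_server(args)

Left unset, `chunked_prefill_size` defaults to `1 << 30` in `__post_init__`, which effectively disables chunking; the patch also asserts that chunked prefill is not combined with `disable_radix_cache`.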