From 98111fbe3ebd429258923ae00c3e1c7b1be8dcec Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Mon, 29 Jul 2024 02:38:31 -0700 Subject: [PATCH] Revert "Chunked prefill support" (#799) --- .../managers/controller/schedule_heuristic.py | 20 +-- .../srt/managers/controller/tp_worker.py | 118 ++++-------------- python/sglang/srt/memory_pool.py | 2 +- python/sglang/srt/server.py | 53 +++----- python/sglang/srt/server_args.py | 21 +--- 5 files changed, 54 insertions(+), 160 deletions(-) diff --git a/python/sglang/srt/managers/controller/schedule_heuristic.py b/python/sglang/srt/managers/controller/schedule_heuristic.py index d1f45836b..88620bf99 100644 --- a/python/sglang/srt/managers/controller/schedule_heuristic.py +++ b/python/sglang/srt/managers/controller/schedule_heuristic.py @@ -38,24 +38,24 @@ class ScheduleHeuristic: self.max_total_num_tokens = max_total_num_tokens self.tree_cache = tree_cache - def get_priority_queue(self, waiting_queue): + def get_priority_queue(self, forward_queue): if self.schedule_heuristic == "lpm": # longest prefix match - waiting_queue.sort(key=lambda x: -len(x.prefix_indices)) - return waiting_queue + forward_queue.sort(key=lambda x: -len(x.prefix_indices)) + return forward_queue elif self.schedule_heuristic == "fcfs": # first come first serve - return waiting_queue + return forward_queue elif self.schedule_heuristic == "lof": # longest output first - waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) - return waiting_queue + forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) + return forward_queue elif self.schedule_heuristic == "random": - random.shuffle(waiting_queue) - return waiting_queue + random.shuffle(forward_queue) + return forward_queue elif self.schedule_heuristic == "dfs-weight": last_node_to_reqs = defaultdict(list) - for req in waiting_queue: + for req in forward_queue: last_node_to_reqs[req.last_node].append(req) node_to_weight = defaultdict(int) @@ -67,7 +67,7 @@ class ScheduleHeuristic: self.get_dfs_priority( self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q ) - assert len(q) == len(waiting_queue) + assert len(q) == len(forward_queue) return q else: raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}") diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index 631ddba95..94e535d14 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -77,10 +77,6 @@ class ModelTpServer: self.schedule_heuristic = server_args.schedule_heuristic self.disable_regex_jump_forward = server_args.disable_regex_jump_forward - # Chunked prefill - self.chunked_prefill_size = server_args.chunked_prefill_size - self.current_inflight_req = None - # Init model and tokenizer self.model_config = ModelConfig( server_args.model_path, @@ -161,7 +157,7 @@ class ModelTpServer: self.token_to_kv_pool = self.model_runner.token_to_kv_pool # Init running status - self.waiting_queue: List[Req] = [] + self.forward_queue: List[Req] = [] self.running_batch: Batch = None self.out_pyobjs = [] self.decode_forward_ct = 0 @@ -224,7 +220,6 @@ class ModelTpServer: # Run a new prefill batch self.forward_prefill_batch(new_batch) self.cache_filled_batch(new_batch) - self.filter_out_inflight(new_batch) if not new_batch.is_empty(): if self.running_batch is None: @@ -266,7 +261,7 @@ class ModelTpServer: f"#token: {num_used}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"gen throughput 
(token/s): {throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}" + f"#queue-req: {len(self.forward_queue)}" ) def check_memory(self): @@ -333,10 +328,9 @@ class ModelTpServer: ), self.max_req_input_len - 1 - len(req.origin_input_ids), ) - self.waiting_queue.append(req) + self.forward_queue.append(req) def get_new_prefill_batch(self) -> Optional[Batch]: - # TODO(lsyin): organize this function running_bs = ( len(self.running_batch.reqs) if self.running_batch is not None else 0 ) @@ -344,7 +338,7 @@ class ModelTpServer: return # Compute matched prefix length - for req in self.waiting_queue: + for req in self.forward_queue: req.input_ids = req.origin_input_ids + req.output_ids prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids) if req.return_logprob: @@ -354,7 +348,7 @@ class ModelTpServer: req.last_node = last_node # Get priority queue - self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue) + self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue) # Add requests if there is available space can_run_list = [] @@ -373,33 +367,7 @@ class ModelTpServer: ] ) - # Handle the current inflight request - take_inflight = 0 - if self.current_inflight_req: - take_inflight = 1 - r = self.current_inflight_req - r.input_ids = r.origin_input_ids + r.output_ids - truncated = ( - len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size - ) - r.extend_input_len = min( - len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size - ) - r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len] - can_run_list.append(r) - - if not truncated: - # Finish inflight - self.current_inflight_req = None - new_batch_total_tokens += ( - r.extend_input_len + r.sampling_params.max_new_tokens - ) - new_batch_input_tokens += r.extend_input_len - else: - new_batch_total_tokens += r.extend_input_len - new_batch_input_tokens += r.extend_input_len - - for req in self.waiting_queue: + for req in self.forward_queue: if req.return_logprob and req.normalized_prompt_logprob is None: # Need at least two tokens to compute normalized logprob if req.extend_input_len < 2: @@ -441,36 +409,11 @@ class ModelTpServer: break else: # Add this request to the running batch - if ( - new_batch_input_tokens + req.extend_input_len - <= self.chunked_prefill_size - or ( - req.return_logprob and req.normalized_prompt_logprob is None - ) - ): - can_run_list.append(req) - new_batch_total_tokens += ( - req.extend_input_len + req.sampling_params.max_new_tokens - ) - new_batch_input_tokens += req.extend_input_len - else: - trunc_len = self.chunked_prefill_size - new_batch_input_tokens - - if trunc_len <= 0: - # Undo locking - delta = self.tree_cache.dec_lock_ref(req.last_node) - available_size += delta - break - - req.extend_input_len = trunc_len - req.input_ids = req.input_ids[ - : len(req.prefix_indices) + req.extend_input_len - ] - can_run_list.append(req) - self.current_inflight_req = req - new_batch_input_tokens += req.extend_input_len - new_batch_total_tokens += req.extend_input_len - break + can_run_list.append(req) + new_batch_total_tokens += ( + req.extend_input_len + req.sampling_params.max_new_tokens + ) + new_batch_input_tokens += req.extend_input_len else: break @@ -497,7 +440,7 @@ class ModelTpServer: f"#cached-token: {hit_tokens}, " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"#running-req: {running_bs}, " - f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}" + f"#queue-req: {len(self.forward_queue) - 
len(can_run_list)}" ) # Return the new batch @@ -507,7 +450,7 @@ class ModelTpServer: self.token_to_kv_pool, self.tree_cache, ) - self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list] + self.forward_queue = [x for x in self.forward_queue if x not in can_run_list] return new_batch def forward_prefill_batch(self, batch: Batch): @@ -539,10 +482,9 @@ class ModelTpServer: # Check finish conditions pt = 0 for i, req in enumerate(batch.reqs): - if req is not self.current_inflight_req: - req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_ids[i]) - req.check_finished() + req.completion_tokens_wo_jump_forward += 1 + req.output_ids.append(next_token_ids[i]) + req.check_finished() if req.return_logprob: self.add_logprob_return_values(i, req, pt, next_token_ids, output) @@ -603,7 +545,7 @@ class ModelTpServer: req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() for i, req in enumerate(batch.reqs): new_prefix_indices, new_last_node = self.tree_cache.cache_req( - token_ids=tuple(req.input_ids), + token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1], last_uncached_pos=len(req.prefix_indices), req_pool_idx=req_pool_indices_cpu[i], del_in_memory_pool=False, @@ -611,10 +553,6 @@ class ModelTpServer: ) req.prefix_indices, req.last_node = new_prefix_indices, new_last_node - if req is self.current_inflight_req: - # inflight request would get a new req idx - self.req_to_token_pool.free(int(req_pool_indices_cpu[i])) - def forward_decode_batch(self, batch: Batch): # Check if decode out of memory if not batch.check_decode_mem(): @@ -628,7 +566,7 @@ class ModelTpServer: f"#retracted_reqs: {len(retracted_reqs)}, " f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" ) - self.waiting_queue.extend(retracted_reqs) + self.forward_queue.extend(retracted_reqs) else: self.new_token_ratio = max( self.new_token_ratio - self.new_token_ratio_decay, @@ -638,7 +576,7 @@ class ModelTpServer: if not self.disable_regex_jump_forward: # Check for jump-forward jump_forward_reqs = batch.check_for_jump_forward(self.model_runner) - self.waiting_queue.extend(jump_forward_reqs) + self.forward_queue.extend(jump_forward_reqs) if batch.is_empty(): return @@ -773,18 +711,8 @@ class ModelTpServer: else: batch.reqs = [] - def filter_out_inflight(self, batch: Batch): - # TODO(lsyin): reduce the overhead, make a special version for this - if self.current_inflight_req is None: - return - - unfinished_indices = list(range(len(batch.reqs))) - unfinished_indices.remove(batch.reqs.index(self.current_inflight_req)) - - batch.filter_batch(unfinished_indices) - def flush_cache(self): - if len(self.waiting_queue) == 0 and ( + if len(self.forward_queue) == 0 and ( self.running_batch is None or len(self.running_batch.reqs) == 0 ): self.tree_cache.reset() @@ -797,20 +725,20 @@ class ModelTpServer: else: warnings.warn( f"Cache not flushed because there are pending requests. 
" - f"#queue-req: {len(self.waiting_queue)}, " + f"#queue-req: {len(self.forward_queue)}, " f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" ) def abort_request(self, recv_req): # Delete requests in the waiting queue to_del = None - for i, req in enumerate(self.waiting_queue): + for i, req in enumerate(self.forward_queue): if req.rid == recv_req.rid: to_del = i break if to_del is not None: - del self.waiting_queue[to_del] + del self.forward_queue[to_del] # Delete requests in the running batch if self.running_batch: diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py index fa38ee41c..d4948b525 100644 --- a/python/sglang/srt/memory_pool.py +++ b/python/sglang/srt/memory_pool.py @@ -45,7 +45,7 @@ class ReqToTokenPool: return select_index - def free(self, free_index): + def free(self, free_index: int): self.mem_state[free_index] = True if isinstance(free_index, (int,)): self.can_use_mem_size += 1 diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 41d71545a..e4801deeb 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -175,39 +175,6 @@ def _set_torch_compile_config(): torch._dynamo.config.accumulated_cache_size_limit = 256 -def set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # Set ulimit - set_ulimit() - - # Enable show time cost for debugging - if server_args.show_time_cost: - enable_show_time_cost() - - # Disable disk cache - if server_args.disable_disk_cache: - disable_cache() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Set torch compile config - if server_args.enable_torch_compile: - _set_torch_compile_config() - - # Set global chat template - if server_args.chat_template: - # TODO: replace this with huggingface transformers template - load_chat_template_for_openai_api(server_args.chat_template) - - def launch_server( server_args: ServerArgs, model_overide_args: Optional[dict] = None, @@ -223,6 +190,16 @@ def launch_server( format="%(message)s", ) + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + set_ulimit() + if server_args.show_time_cost: + enable_show_time_cost() + if server_args.disable_disk_cache: + disable_cache() if not server_args.disable_flashinfer: assert_pkg_version( "flashinfer", @@ -231,8 +208,14 @@ def launch_server( "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", ) - - set_envs_and_config(server_args) + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. 
+ maybe_set_triton_cache_manager() + if server_args.chat_template: + # TODO: replace this with huggingface transformers template + load_chat_template_for_openai_api(server_args.chat_template) + if server_args.enable_torch_compile: + _set_torch_compile_config() # Allocate ports server_args.port, server_args.additional_ports = allocate_init_ports( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 224b8d879..5c22c48e6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -65,9 +65,6 @@ class ServerArgs: dp_size: int = 1 load_balance_method: str = "round_robin" - # Chunked Prefill - chunked_prefill_size: Optional[int] = None - # Optimization/debug options disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False @@ -86,8 +83,6 @@ class ServerArgs: node_rank: Optional[int] = None def __post_init__(self): - if self.chunked_prefill_size is None: - self.chunked_prefill_size = int(10**9) if self.tokenizer_path is None: self.tokenizer_path = self.model_path if self.mem_fraction_static is None: @@ -228,7 +223,7 @@ class ServerArgs: parser.add_argument( "--max-num-reqs", type=int, - default=ServerArgs.max_num_reqs, + default=None, help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.", ) parser.add_argument( @@ -316,18 +311,10 @@ class ServerArgs: help="The nccl init address of multi-node server.", ) parser.add_argument( - "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes." + "--nnodes", type=int, default=1, help="The number of nodes." ) parser.add_argument("--node-rank", type=int, help="The node rank.") - # Chunked prefill - parser.add_argument( - "--chunked-prefill-size", - type=int, - default=ServerArgs.chunked_prefill_size, - help="The size of the chunked prefill.", - ) - # Optimization/debug options parser.add_argument( "--disable-flashinfer", @@ -406,10 +393,6 @@ class ServerArgs: self.dp_size > 1 and self.node_rank is not None ), "multi-node data parallel is not supported" - assert not ( - self.chunked_prefill_size is not None and self.disable_radix_cache - ), "chunked prefill is not supported with radix cache disabled currently" - @dataclasses.dataclass class PortArgs:
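
Note for reviewers: the tp_worker.py hunks above remove the chunked-prefill path, in which a prompt longer than the per-batch token budget is prefilled over several scheduling rounds and the partially prefilled request is carried across rounds as `current_inflight_req`. The sketch below illustrates that scheduling idea only; it is not the reverted implementation. `PendingReq`, `prefix_len`, and `schedule_prefill_chunk` are illustrative stand-ins, and the real scheduler additionally tracks KV-cache budgets, logprob requirements, and radix-cache locks.

    from dataclasses import dataclass
    from typing import List, Optional, Tuple


    @dataclass
    class PendingReq:
        rid: str
        input_ids: List[int]   # prompt tokens (plus any tokens added by jump-forward)
        prefix_len: int = 0    # tokens whose KV entries already exist (cached or prefilled)


    def schedule_prefill_chunk(
        queue: List[PendingReq],
        inflight: Optional[PendingReq],
        chunked_prefill_size: int,
    ) -> Tuple[List[Tuple[PendingReq, int]], Optional[PendingReq]]:
        """Pick (request, num_tokens_to_prefill_now) pairs for one prefill batch."""
        # A partially prefilled request always resumes before new requests are admitted.
        ordered = ([inflight] if inflight is not None else []) + [
            r for r in queue if r is not inflight
        ]

        batch: List[Tuple[PendingReq, int]] = []
        budget = chunked_prefill_size
        new_inflight: Optional[PendingReq] = None

        for req in ordered:
            remaining = len(req.input_ids) - req.prefix_len
            take = min(remaining, budget)
            if take <= 0:
                break
            batch.append((req, take))
            budget -= take
            if take < remaining:
                # The prompt did not fit: prefill only `take` tokens now and keep the
                # request inflight so the next round continues from the longer prefix.
                new_inflight = req
                break
        return batch, new_inflight

After the forward pass, the caller would advance `prefix_len` by the tokens just prefilled; only a request whose whole prompt has been consumed samples a token and joins the decode batch, which is why the reverted `forward_prefill_batch` skipped `output_ids.append(...)` for `self.current_inflight_req`.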
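
The schedule_heuristic.py hunks only rename `waiting_queue` back to `forward_queue`; the ordering policies themselves are unchanged. For context, here is a condensed, standalone illustration of the simple policies, using a stand-in request type rather than the real `Req` class (the tree-walking "dfs-weight" policy is omitted):

    import random
    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class QueuedReq:
        rid: str
        prefix_indices: List[int] = field(default_factory=list)  # tokens matched in the radix cache
        max_new_tokens: int = 16


    def order_queue(reqs: List[QueuedReq], heuristic: str) -> List[QueuedReq]:
        if heuristic == "lpm":
            # Longest prefix match: requests reusing the most cached KV entries go first.
            return sorted(reqs, key=lambda r: -len(r.prefix_indices))
        if heuristic == "fcfs":
            # First come, first served: keep arrival order.
            return list(reqs)
        if heuristic == "lof":
            # Longest output first.
            return sorted(reqs, key=lambda r: -r.max_new_tokens)
        if heuristic == "random":
            shuffled = list(reqs)
            random.shuffle(shuffled)
            return shuffled
        raise ValueError(f"Unknown schedule_heuristic: {heuristic}")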