From 39191c851532b8899b81c8dfac1bf558ee6be160 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Mon, 13 May 2024 12:47:13 +0800 Subject: [PATCH] Cache optimizations (#418) --- python/sglang/global_config.py | 3 + python/sglang/srt/backend_config.py | 13 ---- .../sglang/srt/managers/router/infer_batch.py | 36 +++++------ python/sglang/srt/managers/router/manager.py | 8 +-- .../sglang/srt/managers/router/model_rpc.py | 41 ++++++++---- .../sglang/srt/managers/router/radix_cache.py | 64 ++++++++++++++----- .../sglang/srt/managers/router/scheduler.py | 45 +++++-------- python/sglang/srt/server_args.py | 3 +- 8 files changed, 117 insertions(+), 96 deletions(-) delete mode 100644 python/sglang/srt/backend_config.py diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index ef0853b7e..e746d7f1d 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -25,5 +25,8 @@ class GlobalConfig: # adjust_cache: Adjust the position embedding of KV cache. self.concate_and_append_mode = "no_adjust" + # Request dependency time due to network delay + self.request_dependency_time = 0.03 + global_config = GlobalConfig() diff --git a/python/sglang/srt/backend_config.py b/python/sglang/srt/backend_config.py deleted file mode 100644 index 107ae0ecb..000000000 --- a/python/sglang/srt/backend_config.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -Backend configurations, may vary with different serving platforms. -""" - -from dataclasses import dataclass - - -@dataclass -class BackendConfig: - extend_dependency_time: float = 0.03 - - -GLOBAL_BACKEND_CONFIG = BackendConfig() diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index a46e1e9db..a2420fa93 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -335,20 +335,20 @@ class Batch: req = self.reqs[idx] retracted_reqs.append(req) - self.tree_cache.dec_ref_counter(req.last_node) + # TODO: apply more fine-grained retraction + last_uncached_pos = len(req.prefix_indices) + token_indices = self.req_to_token_pool.req_to_token[ + req_pool_indices_cpu[idx] + ][last_uncached_pos : seq_lens_cpu[idx]] + self.token_to_kv_pool.dec_refs(token_indices) + + self.tree_cache.dec_lock_ref(req.last_node) req.prefix_indices = None req.last_node = None req.extend_input_len = 0 req.output_ids = [] req.regex_fsm_state = 0 - # TODO: apply more fine-grained retraction - - token_indices = self.req_to_token_pool.req_to_token[ - req_pool_indices_cpu[idx] - ][: seq_lens_cpu[idx]] - self.token_to_kv_pool.dec_refs(token_indices) - self.filter_batch(sorted_indices) return retracted_reqs @@ -367,20 +367,18 @@ class Batch: if len(jump_forward_str) <= 1: continue - # insert the old request into tree_cache - token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1] if req_pool_indices_cpu is None: req_pool_indices_cpu = self.req_pool_indices.tolist() - req_pool_idx = req_pool_indices_cpu[i] - indices = self.req_to_token_pool.req_to_token[ - req_pool_idx, : len(token_ids_in_memory) - ] - prefix_len = self.tree_cache.insert( - token_ids_in_memory, indices.clone() + + # insert the old request into tree_cache + self.tree_cache.cache_req( + token_ids=tuple(req.input_ids + req.output_ids)[:-1], + last_uncached_pos=len(req.prefix_indices), + req_pool_idx=req_pool_indices_cpu[i], ) - self.token_to_kv_pool.dec_refs(indices[:prefix_len]) - self.req_to_token_pool.free(req_pool_idx) - self.tree_cache.dec_ref_counter(req.last_node) + + # unlock the last node + self.tree_cache.dec_lock_ref(req.last_node) # jump-forward req.jump_forward_and_retokenize(jump_forward_str, next_state) diff --git a/python/sglang/srt/managers/router/manager.py b/python/sglang/srt/managers/router/manager.py index c331ae2bb..af0664509 100644 --- a/python/sglang/srt/managers/router/manager.py +++ b/python/sglang/srt/managers/router/manager.py @@ -5,7 +5,7 @@ import uvloop import zmq import zmq.asyncio -from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG +from sglang import global_config from sglang.srt.managers.router.model_rpc import ModelRpcClient from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import get_exception_traceback @@ -30,7 +30,7 @@ class RouterManager: self.recv_reqs = [] # Init some configs - self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time + self.request_dependency_time = global_config.request_dependency_time async def loop_for_forward(self): while True: @@ -46,9 +46,9 @@ class RouterManager: if len(out_pyobjs) != 0: has_finished = any([obj.finished for obj in out_pyobjs]) if has_finished: - if self.extend_dependency_time > 0: + if self.request_dependency_time > 0: slept = True - await asyncio.sleep(self.extend_dependency_time) + await asyncio.sleep(self.request_dependency_time) if not slept: await asyncio.sleep(0.0006) diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index f85faecd0..f9e7153a8 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -117,7 +117,11 @@ class ModelRpcServer: logger.info(f"server_args: {server_args.print_mode_args()}") # Init cache - self.tree_cache = RadixCache(disable=server_args.disable_radix_cache) + self.tree_cache = RadixCache( + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + disable=server_args.disable_radix_cache, + ) self.tree_cache_metrics = {"total": 0, "hit": 0} self.scheduler = Scheduler( self.schedule_heuristic, @@ -203,6 +207,8 @@ class ModelRpcServer: # Run new fill batch self.forward_fill_batch(new_batch) + self.cache_filled_batch(new_batch) + if not new_batch.is_empty(): if self.running_batch is None: self.running_batch = new_batch @@ -349,20 +355,19 @@ class ModelRpcServer: and req.extend_input_len + new_batch_input_tokens < self.max_prefill_num_token ): - delta = self.tree_cache.inc_ref_counter(req.last_node) + delta = self.tree_cache.inc_lock_ref(req.last_node) available_size += delta if not ( req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens < available_size ): - # Undo the insertion - delta = self.tree_cache.dec_ref_counter(req.last_node) + # Undo locking + delta = self.tree_cache.dec_lock_ref(req.last_node) available_size += delta break else: # Add this request to the running batch - self.token_to_kv_pool.add_refs(req.prefix_indices) can_run_list.append(req) new_batch_total_tokens += ( req.extend_input_len + req.max_new_tokens() @@ -477,6 +482,18 @@ class ModelRpcServer: self.handle_finished_requests(batch) + def cache_filled_batch(self, batch: Batch): + req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist() + for i, req in enumerate(batch.reqs): + new_prefix_indices, new_last_node = self.tree_cache.cache_req( + token_ids=tuple(req.input_ids + req.output_ids)[:-1], + last_uncached_pos=len(req.prefix_indices), + req_pool_idx=req_pool_indices_cpu[i], + del_in_memory_pool=False, + old_last_node=req.last_node, + ) + req.prefix_indices, req.last_node = new_prefix_indices, new_last_node + def forward_decode_batch(self, batch: Batch): # check if decode out of memory if not batch.check_decode_mem(): @@ -636,17 +653,13 @@ class ModelRpcServer: req_pool_indices_cpu = batch.req_pool_indices.tolist() for i in finished_indices: req = batch.reqs[i] - req_pool_idx = req_pool_indices_cpu[i] - token_ids = tuple(req.input_ids + req.output_ids) - seq_len = len(token_ids) - 1 - indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len] - prefix_len = self.tree_cache.insert( - token_ids[:seq_len], indices.clone() + self.tree_cache.cache_req( + token_ids=tuple(req.input_ids + req.output_ids)[:-1], + last_uncached_pos=len(req.prefix_indices), + req_pool_idx=req_pool_indices_cpu[i], ) - self.token_to_kv_pool.dec_refs(indices[:prefix_len]) - self.req_to_token_pool.free(req_pool_idx) - self.tree_cache.dec_ref_counter(req.last_node) + self.tree_cache.dec_lock_ref(req.last_node) # Update batch tensors if unfinished_indices: diff --git a/python/sglang/srt/managers/router/radix_cache.py b/python/sglang/srt/managers/router/radix_cache.py index ce097afa7..855a10bb6 100644 --- a/python/sglang/srt/managers/router/radix_cache.py +++ b/python/sglang/srt/managers/router/radix_cache.py @@ -11,7 +11,7 @@ class TreeNode: self.parent = None self.key = None self.value = None - self.ref_counter = 0 + self.lock_ref = 0 self.last_access_time = time.time() def __lt__(self, other: "TreeNode"): @@ -28,7 +28,9 @@ def _key_match(key0, key1): class RadixCache: - def __init__(self, disable: bool = False): + def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False): + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool = token_to_kv_pool self.disable = disable self.reset() @@ -38,7 +40,7 @@ class RadixCache: self.root_node = TreeNode() self.root_node.key = [] self.root_node.value = [] - self.root_node.ref_counter = 1 + self.root_node.lock_ref = 1 self.evictable_size_ = 0 def match_prefix(self, key): @@ -50,6 +52,8 @@ class RadixCache: self._match_prefix_helper(self.root_node, key, value, last_node) if value: value = torch.concat(value) + else: + value = torch.tensor([], dtype=torch.int64) return value, last_node[0] def insert(self, key, value=None): @@ -60,6 +64,34 @@ class RadixCache: value = [x for x in key] return self._insert_helper(self.root_node, key, value) + def cache_req( + self, + token_ids, + last_uncached_pos, + req_pool_idx, + del_in_memory_pool=True, + old_last_node=None, + ): + # Insert the request into radix cache + indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)] + new_prefix_len = self.insert(token_ids, indices.clone()) + + # Radix Cache takes one ref in memory pool + self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len]) + + if del_in_memory_pool: + self.req_to_token_pool.free(req_pool_idx) + else: + cached_indices, new_last_node = self.match_prefix(token_ids) + assert len(cached_indices) == len(token_ids) + + self.req_to_token_pool.req_to_token[ + req_pool_idx, last_uncached_pos : len(cached_indices) + ] = cached_indices[last_uncached_pos:] + self.dec_lock_ref(old_last_node) + self.inc_lock_ref(new_last_node) + return cached_indices, new_last_node + def pretty_print(self): self._print_helper(self.root_node, 0) print(f"#tokens: {self.total_size()}") @@ -80,7 +112,7 @@ class RadixCache: if x == self.root_node: break - if x.ref_counter > 0: + if x.lock_ref > 0: continue num_evicted += evict_callback(x.value) @@ -89,23 +121,23 @@ class RadixCache: if len(x.parent.children) == 0: heapq.heappush(leaves, x.parent) - def inc_ref_counter(self, node): + def inc_lock_ref(self, node: TreeNode): delta = 0 while node != self.root_node: - if node.ref_counter == 0: + if node.lock_ref == 0: self.evictable_size_ -= len(node.value) delta -= len(node.value) - node.ref_counter += 1 + node.lock_ref += 1 node = node.parent return delta - def dec_ref_counter(self, node): + def dec_lock_ref(self, node: TreeNode): delta = 0 while node != self.root_node: - if node.ref_counter == 1: + if node.lock_ref == 1: self.evictable_size_ += len(node.value) delta += len(node.value) - node.ref_counter -= 1 + node.lock_ref -= 1 node = node.parent return delta @@ -131,12 +163,12 @@ class RadixCache: last_node[0] = child self._match_prefix_helper(child, key[prefix_len:], value, last_node) - def _split_node(self, key, child, split_len): + def _split_node(self, key, child: TreeNode, split_len): # new_node -> child new_node = TreeNode() new_node.children = {key[split_len:][0]: child} new_node.parent = child.parent - new_node.ref_counter = child.ref_counter + new_node.lock_ref = child.lock_ref new_node.key = child.key[:split_len] new_node.value = child.value[:split_len] child.parent = new_node @@ -176,11 +208,9 @@ class RadixCache: self.evictable_size_ += len(value) return 0 - def _print_helper(self, node, indent): + def _print_helper(self, node: TreeNode, indent): for _, child in node.children.items(): - print( - " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}" - ) + print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}") self._print_helper(child, indent=indent + 2) def _delete_leaf(self, node): @@ -211,7 +241,7 @@ class RadixCache: if __name__ == "__main__": - tree = RadixCache() + tree = RadixCache(None, None, False) tree.insert("Hello") tree.insert("Hello") diff --git a/python/sglang/srt/managers/router/scheduler.py b/python/sglang/srt/managers/router/scheduler.py index 9affd970f..806151931 100644 --- a/python/sglang/srt/managers/router/scheduler.py +++ b/python/sglang/srt/managers/router/scheduler.py @@ -27,44 +27,33 @@ class Scheduler: return forward_queue elif self.schedule_heuristic == "fcfs": return forward_queue - elif self.schedule_heuristic == "weight": + elif self.schedule_heuristic == "dfs-weight": last_node_to_reqs = defaultdict(list) for req in forward_queue: last_node_to_reqs[req.last_node].append(req) - for node in last_node_to_reqs: - last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices)) node_to_weight = defaultdict(int) - self._calc_weight_recursive( - self.tree_cache.root_node, last_node_to_reqs, node_to_weight - ) + for node in last_node_to_reqs: + node_to_weight[node] = len(last_node_to_reqs[node]) + self.calc_weight(self.tree_cache.root_node, node_to_weight) - tmp_queue = [] - self._get_weight_priority_recursive( - self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue + q = [] + self.get_dfs_priority( + self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q ) - assert len(tmp_queue) == len(forward_queue) - return tmp_queue + assert len(q) == len(forward_queue) + return q else: raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}") - def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight): - node_to_weight[cur_node] = 1 - if cur_node in last_node_to_reqs: - node_to_weight[cur_node] += len(last_node_to_reqs[cur_node]) + def calc_weight(self, cur_node, node_to_weight): for child in cur_node.children.values(): - self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight) + self.calc_weight(child, node_to_weight) node_to_weight[cur_node] += node_to_weight[child] - def _get_weight_priority_recursive( - self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue - ): - visit_list = [child for child in cur_node.children.values()] - visit_list.sort(key=lambda x: -node_to_wight[x]) - # for node in visit_list: - # print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}") - for child in visit_list: - self._get_weight_priority_recursive( - child, node_to_wight, last_node_to_reqs, tmp_queue - ) - tmp_queue.extend(last_node_to_reqs[cur_node]) + def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q): + childs = [child for child in cur_node.children.values()] + childs.sort(key=lambda x: -node_to_priority[x]) + for child in childs: + self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q) + q.extend(last_node_to_reqs[cur_node]) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 43e7514d5..ccf322c0a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -149,7 +149,8 @@ class ServerArgs: "--schedule-heuristic", type=str, default=ServerArgs.schedule_heuristic, - help="Schudule mode: [lpm, weight, random, fcfs]", + choices=["lpm", "random", "fcfs", "dfs-weight"], + help="Scheduling Heuristic.", ) parser.add_argument( "--schedule-conservativeness",