diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py
index c52892c14..3920fe039 100644
--- a/python/sglang/srt/managers/router/infer_batch.py
+++ b/python/sglang/srt/managers/router/infer_batch.py
@@ -236,9 +236,8 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            if not self.tree_cache.disable:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

             if out_cache_loc is None:
                 print("Prefill out of memory. This should never happen.")
@@ -307,8 +306,8 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True

-        if not self.tree_cache.disable:
-            self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+
         if self.token_to_kv_pool.available_size() >= bs:
             return True

@@ -341,7 +340,7 @@ class Batch:
             token_indices = self.req_to_token_pool.req_to_token[
                 req_pool_indices_np[idx]
             ][: seq_lens_np[idx]]
-            self.token_to_kv_pool.free(token_indices)
+            self.token_to_kv_pool.dec_refs(token_indices)

         self.filter_batch(sorted_indices)

@@ -372,7 +371,7 @@ class Batch:
                 prefix_len = self.tree_cache.insert(
                     token_ids_in_memory, indices.clone()
                 )
-                self.token_to_kv_pool.free(indices[:prefix_len])
+                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                 self.req_to_token_pool.free(req_pool_idx)
                 self.tree_cache.dec_ref_counter(req.last_node)

diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 0ca46c854..02c98560b 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -113,7 +113,7 @@ class ModelRpcServer:
         logger.info(server_args.get_optional_modes_logging())

         # Init cache
-        self.tree_cache = RadixCache(server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -628,7 +628,7 @@ class ModelRpcServer:
                     token_ids[:seq_len], indices.clone()
                 )

-                self.token_to_kv_pool.free(indices[:prefix_len])
+                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                 self.req_to_token_pool.free(req_pool_idx)
                 self.tree_cache.dec_ref_counter(req.last_node)
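
The hunks above converge on a single allocate-evict-retry pattern: because RadixCache.evict() (below) now returns quietly when the cache is disabled, callers can drop the "if not self.tree_cache.disable" guard and always pass dec_refs as the eviction callback. A minimal sketch of that pattern, with pool and cache as hypothetical stand-ins for TokenToKVPool and RadixCache:

def alloc_with_eviction(pool, cache, need_tokens):
    # First attempt: allocate straight from the free slots.
    loc = pool.alloc(need_tokens)
    if loc is None:
        # Reclaim cached-but-unreferenced KV slots, then retry once.
        # evict() is a no-op when the cache is disabled, so no guard is needed.
        cache.evict(need_tokens, pool.dec_refs)
        loc = pool.alloc(need_tokens)
    return loc  # may still be None if eviction could not reclaim enough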
diff --git a/python/sglang/srt/managers/router/radix_cache.py b/python/sglang/srt/managers/router/radix_cache.py
index c7bd9cb6b..ce097afa7 100644
--- a/python/sglang/srt/managers/router/radix_cache.py
+++ b/python/sglang/srt/managers/router/radix_cache.py
@@ -1,8 +1,6 @@
 import heapq
 import time
 from collections import defaultdict
-from dataclasses import dataclass
-from typing import Tuple

 import torch

@@ -16,23 +14,23 @@ class TreeNode:
         self.ref_counter = 0
         self.last_access_time = time.time()

-    def __lt__(self, other):
+    def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time


-def match(key, seq):
+def _key_match(key0, key1):
     i = 0
-    for k, w in zip(key, seq):
-        if k != w:
+    for k0, k1 in zip(key0, key1):
+        if k0 != k1:
             break
         i += 1
     return i


 class RadixCache:
-    def __init__(self, disable=False):
-        self.reset()
+    def __init__(self, disable: bool = False):
         self.disable = disable
+        self.reset()

     ##### Public API #####

@@ -71,7 +69,7 @@ class RadixCache:
     def evict(self, num_tokens, evict_callback):
         if self.disable:
-            raise RuntimeError()
+            return

         leaves = self._collect_leaves()
         heapq.heapify(leaves)

@@ -115,6 +113,7 @@ class RadixCache:
         return self.evictable_size_

     ##### Internal Helper Functions #####
+
     def _match_prefix_helper(self, node, key, value, last_node):
         node.last_access_time = time.time()
         if len(key) == 0:
@@ -122,7 +121,7 @@ class RadixCache:

         if key[0] in node.children.keys():
             child = node.children[key[0]]
-            prefix_len = match(child.key, key)
+            prefix_len = _key_match(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 value.append(new_node.value)
@@ -153,7 +152,7 @@ class RadixCache:

         if key[0] in node.children.keys():
             child = node.children[key[0]]
-            prefix_len = match(child.key, key)
+            prefix_len = _key_match(child.key, key)

             if prefix_len == len(child.key):
                 if prefix_len == len(key):
@@ -212,7 +211,7 @@


 if __name__ == "__main__":
-    tree = RadixCache(disable=False)
+    tree = RadixCache()

     tree.insert("Hello")
     tree.insert("Hello")
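
For reference, a self-contained sketch of the heap-driven LRU loop that evict() runs over the leaves, and of why TreeNode.__lt__ orders nodes by last_access_time. _Leaf and evict_sketch are hypothetical stand-ins, not the real implementation:

import heapq

class _Leaf:
    # Hypothetical stand-in for TreeNode, comparable by recency so that
    # heapq pops the least recently used leaf first.
    def __init__(self, value, last_access_time, ref_counter=0):
        self.value = value  # KV-pool indices held by this leaf
        self.last_access_time = last_access_time
        self.ref_counter = ref_counter

    def __lt__(self, other):
        return self.last_access_time < other.last_access_time

def evict_sketch(leaves, num_tokens, evict_callback):
    # Pop LRU leaves until enough tokens are reclaimed; leaves still
    # referenced by running requests are skipped.
    heapq.heapify(leaves)
    num_evicted = 0
    while num_evicted < num_tokens and leaves:
        leaf = heapq.heappop(leaves)
        if leaf.ref_counter > 0:
            continue
        evict_callback(leaf.value)  # e.g. token_to_kv_pool.dec_refs
        num_evicted += len(leaf.value)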
diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py
index c93cb6044..33f4b8784 100644
--- a/python/sglang/srt/memory_pool.py
+++ b/python/sglang/srt/memory_pool.py
@@ -31,9 +31,6 @@ class ReqToTokenPool:
         self.can_use_mem_size += free_index.shape[0]
         self.mem_state[free_index] = 1

-        # if self.can_use_mem_size == len(self.mem_state):
-        #     print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
-
     def clear(self):
         self.mem_state.fill_(1)
         self.can_use_mem_size = len(self.mem_state)
@@ -42,7 +39,7 @@ class ReqToTokenPool:
 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
         self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
-        self.alloc_ct = 0
+        self.total_ref_ct = 0

         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
@@ -83,9 +80,6 @@ class TokenToKVPool:
         self.add_refs(select_index)
         return select_index.to(torch.int32), start_loc, start_loc + need_size

-    def free(self, free_index):
-        return self.decrease_refs(free_index)
-
     def used_size(self):
         return len(torch.nonzero(self.mem_state).squeeze(1))

@@ -93,20 +87,17 @@ class TokenToKVPool:
         return torch.sum(self.mem_state == 0).item()

     def add_refs(self, token_index: torch.Tensor):
-        self.alloc_ct += len(token_index)
+        self.total_ref_ct += len(token_index)
         self.mem_state[token_index] += 1

-    def decrease_refs(self, token_index: torch.Tensor):
-        self.alloc_ct -= len(token_index)
+    def dec_refs(self, token_index: torch.Tensor):
+        self.total_ref_ct -= len(token_index)
         self.mem_state[token_index] -= 1

         num_freed = torch.sum(self.mem_state[token_index] == 0)

-        # if self.alloc_ct == 0:
-        #     print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
-
         return num_freed

     def clear(self):
         self.mem_state.fill_(0)
-        self.alloc_ct = 0
+        self.total_ref_ct = 0

diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index b3395f162..79ed26c93 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -500,7 +500,7 @@ async def v1_chat_completions(raw_request: Request):
     return response


-def launch_server(server_args, pipe_finish_writer):
+def launch_server(server_args: ServerArgs, pipe_finish_writer):
     global tokenizer_manager
     global chat_template_name

diff --git a/test/srt/model/test_llama_low_api.py b/test/srt/model/test_llama_low_api.py
index 80a79e0c6..20b59e5c7 100644
--- a/test/srt/model/test_llama_low_api.py
+++ b/test/srt/model/test_llama_low_api.py
@@ -105,7 +105,7 @@ def test_generate_worker(

     for i in range(batch_size):
         req_idx = req_pool_indices[i].item()
-        model.token_to_kv_pool.free(
+        model.token_to_kv_pool.dec_refs(
             model.req_to_token_pool.req_to_token[req_idx, : seq_lens[i]]
         )
     model.req_to_token_pool.free(req_pool_indices)
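
Taken together, the memory_pool.py hunks make the reference-count contract explicit: mem_state[i] is a per-slot counter rather than a busy flag, and a KV slot only becomes free once every holder (a running request or a radix-cache node) has released it via dec_refs. A CPU-only toy illustrating that contract; the class mirrors the names in the diff but is illustrative, not the real pool:

import torch

class ToyKVRefPool:
    def __init__(self, size):
        self.mem_state = torch.zeros((size,), dtype=torch.int16)
        self.total_ref_ct = 0

    def add_refs(self, token_index):
        self.total_ref_ct += len(token_index)
        self.mem_state[token_index] += 1

    def dec_refs(self, token_index):
        self.total_ref_ct -= len(token_index)
        self.mem_state[token_index] -= 1
        # Report how many of the touched slots just dropped to zero refs.
        return torch.sum(self.mem_state[token_index] == 0)

pool = ToyKVRefPool(8)
prefix = torch.tensor([0, 1, 2])
pool.add_refs(prefix)  # a running request holds the shared prefix
pool.add_refs(prefix)  # the radix cache holds the same prefix
assert pool.dec_refs(prefix).item() == 0  # request finishes; cache still holds it
assert pool.dec_refs(prefix).item() == 3  # cache evicts; slots are truly free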