Cache optimizations (#418)
This commit is contained in:
@@ -25,5 +25,8 @@ class GlobalConfig:
|
||||
# adjust_cache: Adjust the position embedding of KV cache.
|
||||
self.concate_and_append_mode = "no_adjust"
|
||||
|
||||
# Request dependency time due to network delay
|
||||
self.request_dependency_time = 0.03
|
||||
|
||||
|
||||
global_config = GlobalConfig()
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""
|
||||
Backend configurations, may vary with different serving platforms.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendConfig:
|
||||
extend_dependency_time: float = 0.03
|
||||
|
||||
|
||||
GLOBAL_BACKEND_CONFIG = BackendConfig()
|
||||
@@ -335,20 +335,20 @@ class Batch:
|
||||
req = self.reqs[idx]
|
||||
retracted_reqs.append(req)
|
||||
|
||||
self.tree_cache.dec_ref_counter(req.last_node)
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = len(req.prefix_indices)
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[idx]
|
||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.dec_refs(token_indices)
|
||||
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
req.prefix_indices = None
|
||||
req.last_node = None
|
||||
req.extend_input_len = 0
|
||||
req.output_ids = []
|
||||
req.regex_fsm_state = 0
|
||||
|
||||
# TODO: apply more fine-grained retraction
|
||||
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_indices_cpu[idx]
|
||||
][: seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.dec_refs(token_indices)
|
||||
|
||||
self.filter_batch(sorted_indices)
|
||||
|
||||
return retracted_reqs
|
||||
@@ -367,20 +367,18 @@ class Batch:
|
||||
if len(jump_forward_str) <= 1:
|
||||
continue
|
||||
|
||||
# insert the old request into tree_cache
|
||||
token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
|
||||
if req_pool_indices_cpu is None:
|
||||
req_pool_indices_cpu = self.req_pool_indices.tolist()
|
||||
req_pool_idx = req_pool_indices_cpu[i]
|
||||
indices = self.req_to_token_pool.req_to_token[
|
||||
req_pool_idx, : len(token_ids_in_memory)
|
||||
]
|
||||
prefix_len = self.tree_cache.insert(
|
||||
token_ids_in_memory, indices.clone()
|
||||
|
||||
# insert the old request into tree_cache
|
||||
self.tree_cache.cache_req(
|
||||
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req_pool_indices_cpu[i],
|
||||
)
|
||||
self.token_to_kv_pool.dec_refs(indices[:prefix_len])
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
self.tree_cache.dec_ref_counter(req.last_node)
|
||||
|
||||
# unlock the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
# jump-forward
|
||||
req.jump_forward_and_retokenize(jump_forward_str, next_state)
|
||||
|
||||
@@ -5,7 +5,7 @@ import uvloop
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
|
||||
from sglang import global_config
|
||||
from sglang.srt.managers.router.model_rpc import ModelRpcClient
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_exception_traceback
|
||||
@@ -30,7 +30,7 @@ class RouterManager:
|
||||
self.recv_reqs = []
|
||||
|
||||
# Init some configs
|
||||
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
|
||||
self.request_dependency_time = global_config.request_dependency_time
|
||||
|
||||
async def loop_for_forward(self):
|
||||
while True:
|
||||
@@ -46,9 +46,9 @@ class RouterManager:
|
||||
if len(out_pyobjs) != 0:
|
||||
has_finished = any([obj.finished for obj in out_pyobjs])
|
||||
if has_finished:
|
||||
if self.extend_dependency_time > 0:
|
||||
if self.request_dependency_time > 0:
|
||||
slept = True
|
||||
await asyncio.sleep(self.extend_dependency_time)
|
||||
await asyncio.sleep(self.request_dependency_time)
|
||||
|
||||
if not slept:
|
||||
await asyncio.sleep(0.0006)
|
||||
|
||||
@@ -117,7 +117,11 @@ class ModelRpcServer:
|
||||
logger.info(f"server_args: {server_args.print_mode_args()}")
|
||||
|
||||
# Init cache
|
||||
self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
|
||||
self.tree_cache = RadixCache(
|
||||
req_to_token_pool=self.model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
disable=server_args.disable_radix_cache,
|
||||
)
|
||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||
self.scheduler = Scheduler(
|
||||
self.schedule_heuristic,
|
||||
@@ -203,6 +207,8 @@ class ModelRpcServer:
|
||||
# Run new fill batch
|
||||
self.forward_fill_batch(new_batch)
|
||||
|
||||
self.cache_filled_batch(new_batch)
|
||||
|
||||
if not new_batch.is_empty():
|
||||
if self.running_batch is None:
|
||||
self.running_batch = new_batch
|
||||
@@ -349,20 +355,19 @@ class ModelRpcServer:
|
||||
and req.extend_input_len + new_batch_input_tokens
|
||||
< self.max_prefill_num_token
|
||||
):
|
||||
delta = self.tree_cache.inc_ref_counter(req.last_node)
|
||||
delta = self.tree_cache.inc_lock_ref(req.last_node)
|
||||
available_size += delta
|
||||
|
||||
if not (
|
||||
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
|
||||
< available_size
|
||||
):
|
||||
# Undo the insertion
|
||||
delta = self.tree_cache.dec_ref_counter(req.last_node)
|
||||
# Undo locking
|
||||
delta = self.tree_cache.dec_lock_ref(req.last_node)
|
||||
available_size += delta
|
||||
break
|
||||
else:
|
||||
# Add this request to the running batch
|
||||
self.token_to_kv_pool.add_refs(req.prefix_indices)
|
||||
can_run_list.append(req)
|
||||
new_batch_total_tokens += (
|
||||
req.extend_input_len + req.max_new_tokens()
|
||||
@@ -477,6 +482,18 @@ class ModelRpcServer:
|
||||
|
||||
self.handle_finished_requests(batch)
|
||||
|
||||
def cache_filled_batch(self, batch: Batch):
|
||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
|
||||
for i, req in enumerate(batch.reqs):
|
||||
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
||||
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req_pool_indices_cpu[i],
|
||||
del_in_memory_pool=False,
|
||||
old_last_node=req.last_node,
|
||||
)
|
||||
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
|
||||
|
||||
def forward_decode_batch(self, batch: Batch):
|
||||
# check if decode out of memory
|
||||
if not batch.check_decode_mem():
|
||||
@@ -636,17 +653,13 @@ class ModelRpcServer:
|
||||
req_pool_indices_cpu = batch.req_pool_indices.tolist()
|
||||
for i in finished_indices:
|
||||
req = batch.reqs[i]
|
||||
req_pool_idx = req_pool_indices_cpu[i]
|
||||
token_ids = tuple(req.input_ids + req.output_ids)
|
||||
seq_len = len(token_ids) - 1
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
|
||||
prefix_len = self.tree_cache.insert(
|
||||
token_ids[:seq_len], indices.clone()
|
||||
self.tree_cache.cache_req(
|
||||
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
|
||||
last_uncached_pos=len(req.prefix_indices),
|
||||
req_pool_idx=req_pool_indices_cpu[i],
|
||||
)
|
||||
|
||||
self.token_to_kv_pool.dec_refs(indices[:prefix_len])
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
self.tree_cache.dec_ref_counter(req.last_node)
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
# Update batch tensors
|
||||
if unfinished_indices:
|
||||
|
||||
@@ -11,7 +11,7 @@ class TreeNode:
|
||||
self.parent = None
|
||||
self.key = None
|
||||
self.value = None
|
||||
self.ref_counter = 0
|
||||
self.lock_ref = 0
|
||||
self.last_access_time = time.time()
|
||||
|
||||
def __lt__(self, other: "TreeNode"):
|
||||
@@ -28,7 +28,9 @@ def _key_match(key0, key1):
|
||||
|
||||
|
||||
class RadixCache:
|
||||
def __init__(self, disable: bool = False):
|
||||
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.disable = disable
|
||||
self.reset()
|
||||
|
||||
@@ -38,7 +40,7 @@ class RadixCache:
|
||||
self.root_node = TreeNode()
|
||||
self.root_node.key = []
|
||||
self.root_node.value = []
|
||||
self.root_node.ref_counter = 1
|
||||
self.root_node.lock_ref = 1
|
||||
self.evictable_size_ = 0
|
||||
|
||||
def match_prefix(self, key):
|
||||
@@ -50,6 +52,8 @@ class RadixCache:
|
||||
self._match_prefix_helper(self.root_node, key, value, last_node)
|
||||
if value:
|
||||
value = torch.concat(value)
|
||||
else:
|
||||
value = torch.tensor([], dtype=torch.int64)
|
||||
return value, last_node[0]
|
||||
|
||||
def insert(self, key, value=None):
|
||||
@@ -60,6 +64,34 @@ class RadixCache:
|
||||
value = [x for x in key]
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_req(
|
||||
self,
|
||||
token_ids,
|
||||
last_uncached_pos,
|
||||
req_pool_idx,
|
||||
del_in_memory_pool=True,
|
||||
old_last_node=None,
|
||||
):
|
||||
# Insert the request into radix cache
|
||||
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
||||
new_prefix_len = self.insert(token_ids, indices.clone())
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
|
||||
|
||||
if del_in_memory_pool:
|
||||
self.req_to_token_pool.free(req_pool_idx)
|
||||
else:
|
||||
cached_indices, new_last_node = self.match_prefix(token_ids)
|
||||
assert len(cached_indices) == len(token_ids)
|
||||
|
||||
self.req_to_token_pool.req_to_token[
|
||||
req_pool_idx, last_uncached_pos : len(cached_indices)
|
||||
] = cached_indices[last_uncached_pos:]
|
||||
self.dec_lock_ref(old_last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
return cached_indices, new_last_node
|
||||
|
||||
def pretty_print(self):
|
||||
self._print_helper(self.root_node, 0)
|
||||
print(f"#tokens: {self.total_size()}")
|
||||
@@ -80,7 +112,7 @@ class RadixCache:
|
||||
|
||||
if x == self.root_node:
|
||||
break
|
||||
if x.ref_counter > 0:
|
||||
if x.lock_ref > 0:
|
||||
continue
|
||||
|
||||
num_evicted += evict_callback(x.value)
|
||||
@@ -89,23 +121,23 @@ class RadixCache:
|
||||
if len(x.parent.children) == 0:
|
||||
heapq.heappush(leaves, x.parent)
|
||||
|
||||
def inc_ref_counter(self, node):
|
||||
def inc_lock_ref(self, node: TreeNode):
|
||||
delta = 0
|
||||
while node != self.root_node:
|
||||
if node.ref_counter == 0:
|
||||
if node.lock_ref == 0:
|
||||
self.evictable_size_ -= len(node.value)
|
||||
delta -= len(node.value)
|
||||
node.ref_counter += 1
|
||||
node.lock_ref += 1
|
||||
node = node.parent
|
||||
return delta
|
||||
|
||||
def dec_ref_counter(self, node):
|
||||
def dec_lock_ref(self, node: TreeNode):
|
||||
delta = 0
|
||||
while node != self.root_node:
|
||||
if node.ref_counter == 1:
|
||||
if node.lock_ref == 1:
|
||||
self.evictable_size_ += len(node.value)
|
||||
delta += len(node.value)
|
||||
node.ref_counter -= 1
|
||||
node.lock_ref -= 1
|
||||
node = node.parent
|
||||
return delta
|
||||
|
||||
@@ -131,12 +163,12 @@ class RadixCache:
|
||||
last_node[0] = child
|
||||
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
|
||||
|
||||
def _split_node(self, key, child, split_len):
|
||||
def _split_node(self, key, child: TreeNode, split_len):
|
||||
# new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {key[split_len:][0]: child}
|
||||
new_node.parent = child.parent
|
||||
new_node.ref_counter = child.ref_counter
|
||||
new_node.lock_ref = child.lock_ref
|
||||
new_node.key = child.key[:split_len]
|
||||
new_node.value = child.value[:split_len]
|
||||
child.parent = new_node
|
||||
@@ -176,11 +208,9 @@ class RadixCache:
|
||||
self.evictable_size_ += len(value)
|
||||
return 0
|
||||
|
||||
def _print_helper(self, node, indent):
|
||||
def _print_helper(self, node: TreeNode, indent):
|
||||
for _, child in node.children.items():
|
||||
print(
|
||||
" " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
|
||||
)
|
||||
print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
|
||||
self._print_helper(child, indent=indent + 2)
|
||||
|
||||
def _delete_leaf(self, node):
|
||||
@@ -211,7 +241,7 @@ class RadixCache:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tree = RadixCache()
|
||||
tree = RadixCache(None, None, False)
|
||||
|
||||
tree.insert("Hello")
|
||||
tree.insert("Hello")
|
||||
|
||||
@@ -27,44 +27,33 @@ class Scheduler:
|
||||
return forward_queue
|
||||
elif self.schedule_heuristic == "fcfs":
|
||||
return forward_queue
|
||||
elif self.schedule_heuristic == "weight":
|
||||
elif self.schedule_heuristic == "dfs-weight":
|
||||
last_node_to_reqs = defaultdict(list)
|
||||
for req in forward_queue:
|
||||
last_node_to_reqs[req.last_node].append(req)
|
||||
for node in last_node_to_reqs:
|
||||
last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))
|
||||
|
||||
node_to_weight = defaultdict(int)
|
||||
self._calc_weight_recursive(
|
||||
self.tree_cache.root_node, last_node_to_reqs, node_to_weight
|
||||
)
|
||||
for node in last_node_to_reqs:
|
||||
node_to_weight[node] = len(last_node_to_reqs[node])
|
||||
self.calc_weight(self.tree_cache.root_node, node_to_weight)
|
||||
|
||||
tmp_queue = []
|
||||
self._get_weight_priority_recursive(
|
||||
self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
|
||||
q = []
|
||||
self.get_dfs_priority(
|
||||
self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
|
||||
)
|
||||
assert len(tmp_queue) == len(forward_queue)
|
||||
return tmp_queue
|
||||
assert len(q) == len(forward_queue)
|
||||
return q
|
||||
else:
|
||||
raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
|
||||
|
||||
def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
|
||||
node_to_weight[cur_node] = 1
|
||||
if cur_node in last_node_to_reqs:
|
||||
node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
|
||||
def calc_weight(self, cur_node, node_to_weight):
|
||||
for child in cur_node.children.values():
|
||||
self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
|
||||
self.calc_weight(child, node_to_weight)
|
||||
node_to_weight[cur_node] += node_to_weight[child]
|
||||
|
||||
def _get_weight_priority_recursive(
|
||||
self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue
|
||||
):
|
||||
visit_list = [child for child in cur_node.children.values()]
|
||||
visit_list.sort(key=lambda x: -node_to_wight[x])
|
||||
# for node in visit_list:
|
||||
# print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
|
||||
for child in visit_list:
|
||||
self._get_weight_priority_recursive(
|
||||
child, node_to_wight, last_node_to_reqs, tmp_queue
|
||||
)
|
||||
tmp_queue.extend(last_node_to_reqs[cur_node])
|
||||
def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
|
||||
childs = [child for child in cur_node.children.values()]
|
||||
childs.sort(key=lambda x: -node_to_priority[x])
|
||||
for child in childs:
|
||||
self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
|
||||
q.extend(last_node_to_reqs[cur_node])
|
||||
|
||||
@@ -149,7 +149,8 @@ class ServerArgs:
|
||||
"--schedule-heuristic",
|
||||
type=str,
|
||||
default=ServerArgs.schedule_heuristic,
|
||||
help="Schudule mode: [lpm, weight, random, fcfs]",
|
||||
choices=["lpm", "random", "fcfs", "dfs-weight"],
|
||||
help="Scheduling Heuristic.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--schedule-conservativeness",
|
||||
|
||||
Reference in New Issue
Block a user