Improve type annotation (#1029)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
|
||||
class BasePrefixCache(ABC):
|
||||
@@ -25,7 +26,7 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evict(self, num_tokens, evict_callback):
|
||||
def evict(self, num_tokens: int, evict_callback: Callable):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -41,7 +42,7 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
def total_size(self):
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
def pretty_print(self):
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -15,7 +18,9 @@ class ChunkCacheEntry:
|
||||
|
||||
|
||||
class ChunkCache(BasePrefixCache):
|
||||
def __init__(self, req_to_token_pool, token_to_kv_pool):
|
||||
def __init__(
|
||||
self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
|
||||
):
|
||||
self.disable = True
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
@@ -32,7 +37,7 @@ class ChunkCache(BasePrefixCache):
|
||||
entry = self.entries[rid]
|
||||
return entry.value, entry
|
||||
|
||||
def cache_finished_req(self, req: "Req", token_ids=None):
|
||||
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
if token_ids is None:
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
|
||||
@@ -45,7 +50,7 @@ class ChunkCache(BasePrefixCache):
|
||||
if req.rid in self.entries:
|
||||
del self.entries[req.rid]
|
||||
|
||||
def cache_unfinished_req(self, req: "Req", token_ids=None):
|
||||
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
if token_ids is None:
|
||||
token_ids = req.fill_ids
|
||||
|
||||
@@ -64,7 +69,7 @@ class ChunkCache(BasePrefixCache):
|
||||
def insert(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def evict(self, num_tokens, evict_callback):
|
||||
def evict(self, num_tokens: int, evict_callback: Callable):
|
||||
pass
|
||||
|
||||
def inc_lock_ref(self, node):
|
||||
|
||||
@@ -16,7 +16,7 @@ limitations under the License.
|
||||
"""Memory pool."""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -42,7 +42,7 @@ class ReqToTokenPool:
|
||||
|
||||
return select_index
|
||||
|
||||
def free(self, free_index):
|
||||
def free(self, free_index: Union[int, List[int]]):
|
||||
if isinstance(free_index, (int,)):
|
||||
self.free_slots.append(free_index)
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -25,6 +27,7 @@ from typing import TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -43,7 +46,7 @@ class TreeNode:
|
||||
return self.last_access_time < other.last_access_time
|
||||
|
||||
|
||||
def _key_match(key0, key1):
|
||||
def _key_match(key0: List, key1: List):
|
||||
i = 0
|
||||
for k0, k1 in zip(key0, key1):
|
||||
if k0 != k1:
|
||||
@@ -53,7 +56,12 @@ def _key_match(key0, key1):
|
||||
|
||||
|
||||
class RadixCache(BasePrefixCache):
|
||||
def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool: BaseTokenToKVPool,
|
||||
disable: bool = False,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.disable = disable
|
||||
@@ -68,7 +76,7 @@ class RadixCache(BasePrefixCache):
|
||||
self.root_node.lock_ref = 1
|
||||
self.evictable_size_ = 0
|
||||
|
||||
def match_prefix(self, key, **kwargs):
|
||||
def match_prefix(self, key: List, **kwargs):
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
|
||||
@@ -81,7 +89,7 @@ class RadixCache(BasePrefixCache):
|
||||
value = torch.tensor([], dtype=torch.int32)
|
||||
return value, last_node[0]
|
||||
|
||||
def insert(self, key, value=None):
|
||||
def insert(self, key: List, value=None):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
@@ -89,7 +97,7 @@ class RadixCache(BasePrefixCache):
|
||||
value = [x for x in key]
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_finished_req(self, req: "Req", token_ids=None):
|
||||
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
"""Cache request when it finishes."""
|
||||
if token_ids is None:
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
@@ -110,7 +118,7 @@ class RadixCache(BasePrefixCache):
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.dec_lock_ref(req.last_node)
|
||||
|
||||
def cache_unfinished_req(self, req: "Req", token_ids=None):
|
||||
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
||||
"""Cache request when it is unfinished."""
|
||||
if self.disable:
|
||||
return
|
||||
@@ -145,7 +153,7 @@ class RadixCache(BasePrefixCache):
|
||||
def total_size(self):
|
||||
return self._total_size_helper(self.root_node)
|
||||
|
||||
def evict(self, num_tokens, evict_callback):
|
||||
def evict(self, num_tokens: int, evict_callback: Callable):
|
||||
if self.disable:
|
||||
return
|
||||
|
||||
@@ -199,7 +207,9 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
##### Internal Helper Functions #####
|
||||
|
||||
def _match_prefix_helper(self, node, key, value, last_node):
|
||||
def _match_prefix_helper(
|
||||
self, node: TreeNode, key: List, value, last_node: TreeNode
|
||||
):
|
||||
node.last_access_time = time.time()
|
||||
if len(key) == 0:
|
||||
return
|
||||
@@ -216,7 +226,7 @@ class RadixCache(BasePrefixCache):
|
||||
last_node[0] = child
|
||||
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
|
||||
|
||||
def _split_node(self, key, child: TreeNode, split_len):
|
||||
def _split_node(self, key, child: TreeNode, split_len: int):
|
||||
# new_node -> child
|
||||
new_node = TreeNode()
|
||||
new_node.children = {key[split_len:][0]: child}
|
||||
@@ -230,7 +240,7 @@ class RadixCache(BasePrefixCache):
|
||||
new_node.parent.children[key[:split_len][0]] = new_node
|
||||
return new_node
|
||||
|
||||
def _insert_helper(self, node, key, value):
|
||||
def _insert_helper(self, node: TreeNode, key: List, value):
|
||||
node.last_access_time = time.time()
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
@@ -261,7 +271,7 @@ class RadixCache(BasePrefixCache):
|
||||
self.evictable_size_ += len(value)
|
||||
return 0
|
||||
|
||||
def _print_helper(self, node: TreeNode, indent):
|
||||
def _print_helper(self, node: TreeNode, indent: int):
|
||||
for _, child in node.children.items():
|
||||
print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
|
||||
self._print_helper(child, indent=indent + 2)
|
||||
@@ -273,7 +283,7 @@ class RadixCache(BasePrefixCache):
|
||||
del node.parent.children[k]
|
||||
self.evictable_size_ -= len(node.key)
|
||||
|
||||
def _total_size_helper(self, node):
|
||||
def _total_size_helper(self, node: TreeNode):
|
||||
x = len(node.value)
|
||||
for child in node.children.values():
|
||||
x += self._total_size_helper(child)
|
||||
|
||||
Reference in New Issue
Block a user