[Core] in batch prefix caching by delay scheduling (#2442)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
|
||||
class BasePrefixCache(ABC):
|
||||
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def match_prefix(self, **kwargs):
|
||||
def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
|
||||
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
|
||||
def reset(self):
|
||||
self.entries = {}
|
||||
|
||||
def match_prefix(self, rid: int, key: List[int]):
|
||||
def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
|
||||
if rid not in self.entries:
|
||||
return [], None
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
|
||||
import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
|
||||
self.root_node.lock_ref = 1
|
||||
self.evictable_size_ = 0
|
||||
|
||||
def match_prefix(self, key: List, **kwargs):
|
||||
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
|
||||
"""Find the matching prefix from the radix tree.
|
||||
Args:
|
||||
key: A list of token IDs to find a matching prefix.
|
||||
Returns:
|
||||
A tuple of a tensor of matching prefix token IDs and
|
||||
the last node that contains the prefix values. Note that
|
||||
this API can modify the internal state of the Radix tree.
|
||||
The last node creates a new child if the prefix is shorter
|
||||
than the last node's value.
|
||||
"""
|
||||
if self.disable:
|
||||
return [], self.root_node
|
||||
|
||||
|
||||
Reference in New Issue
Block a user