[Core] in batch prefix caching by delay scheduling (#2442)

This commit is contained in:
SangBin Cho
2024-12-11 12:51:50 -08:00
committed by GitHub
parent 864bf2ba00
commit 9208618b3e
8 changed files with 87 additions and 16 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable
from typing import Callable, List, Tuple
class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
pass
@abstractmethod
def match_prefix(self, **kwargs):
def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
pass
@abstractmethod

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
def reset(self):
self.entries = {}
def match_prefix(self, rid: int, key: List[int]):
def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
if rid not in self.entries:
return [], None

View File

@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
import torch
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
self.root_node.lock_ref = 1
self.evictable_size_ = 0
def match_prefix(self, key: List, **kwargs):
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
"""Find the matching prefix from the radix tree.
Args:
key: A list of token IDs to find a matching prefix.
Returns:
A tuple of a tensor of matching prefix token IDs and
the last node that contains the prefix values. Note that
this API can modify the internal state of the Radix tree.
The last node creates a new child if the prefix is shorter
than the last node's value.
"""
if self.disable:
return [], self.root_node