[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -21,11 +20,13 @@ class ChunkCacheEntry:
|
||||
|
||||
class ChunkCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
):
|
||||
self.disable = True
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.entries: Dict[str, ChunkCacheEntry] = {}
|
||||
|
||||
self.reset()
|
||||
@@ -51,7 +52,7 @@ class ChunkCache(BasePrefixCache):
|
||||
req.req_pool_idx, :token_id_len
|
||||
]
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
self.token_to_kv_pool.free(kv_indices)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
|
||||
if req.rid in self.entries:
|
||||
del self.entries[req.rid]
|
||||
@@ -91,3 +92,6 @@ class ChunkCache(BasePrefixCache):
|
||||
|
||||
def protected_size(self):
|
||||
return 0
|
||||
|
||||
def pretty_print(self):
|
||||
return ""
|
||||
|
||||
@@ -7,8 +7,8 @@ import torch
|
||||
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
BaseTokenToKVPool,
|
||||
MLATokenToKVPoolHost,
|
||||
MHATokenToKVPool,
|
||||
MHATokenToKVPoolHost,
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
||||
@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool: BaseTokenToKVPool,
|
||||
token_to_kv_pool: MHATokenToKVPool,
|
||||
):
|
||||
self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
|
||||
self.cache_controller = HiCacheController(
|
||||
token_to_kv_pool, self.token_to_kv_pool_host
|
||||
)
|
||||
|
||||
@@ -20,9 +20,12 @@ Memory pool.
|
||||
|
||||
SGLang has two levels of memory pool.
|
||||
ReqToTokenPool maps a a request to its token locations.
|
||||
BaseTokenToKVPool maps a token location to its KV cache data.
|
||||
TokenToKVPoolAllocator maps a token location to its KV cache data.
|
||||
KVCache actually holds the physical kv cache. Allocation indices are allocated
|
||||
by TokenToKVPoolAllocator
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import threading
|
||||
from enum import IntEnum
|
||||
@@ -89,7 +92,7 @@ class ReqToTokenPool:
|
||||
self.free_slots = list(range(self.size))
|
||||
|
||||
|
||||
class BaseTokenToKVPool:
|
||||
class TokenToKVPoolAllocator:
|
||||
"""A memory pool that maps a token location to its kv cache data."""
|
||||
|
||||
def __init__(
|
||||
@@ -100,11 +103,6 @@ class BaseTokenToKVPool:
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.free_slots = None
|
||||
@@ -148,15 +146,22 @@ class BaseTokenToKVPool:
|
||||
self.is_in_free_group = False
|
||||
self.free_group = []
|
||||
|
||||
|
||||
class KVCache(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer: RadixAttention,
|
||||
@@ -167,7 +172,7 @@ class BaseTokenToKVPool:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
class MHATokenToKVPool(KVCache):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
super().__init__(size, dtype, device)
|
||||
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
)
|
||||
@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
||||
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
||||
|
||||
|
||||
class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
class MLATokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
super().__init__(size, dtype, device)
|
||||
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
||||
self.kv_buffer[layer_id][loc] = cache_k
|
||||
|
||||
|
||||
class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
|
||||
class DoubleSparseTokenToKVPool(KVCache):
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
|
||||
heavy_channel_num: int,
|
||||
enable_memory_saver: bool,
|
||||
):
|
||||
super().__init__(size, dtype, device)
|
||||
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
)
|
||||
@@ -437,12 +460,12 @@ def synchronized(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
class MLATokenToKVPoolHost:
|
||||
class MHATokenToKVPoolHost:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MHATokenToKVPool,
|
||||
host_to_device_ratio: float = 4.0,
|
||||
host_to_device_ratio: float = 2.0,
|
||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
||||
device: str = "cpu",
|
||||
):
|
||||
|
||||
@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool: BaseTokenToKVPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
disable: bool = False,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool = token_to_kv_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.disable = disable
|
||||
self.reset()
|
||||
|
||||
@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache):
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, :token_ids_len
|
||||
]
|
||||
self.token_to_kv_pool.free(kv_indices)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
return
|
||||
|
||||
@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
|
||||
# Remove req slot release the cache lock
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache):
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node = self.match_prefix(token_ids)
|
||||
|
||||
Reference in New Issue
Block a user