[Eagle] Refactor eagle speculative decoding (#3986)

Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
Ying Sheng
2025-03-05 08:06:07 -08:00
committed by GitHub
parent 5be8f1ed98
commit d3d4d76758
22 changed files with 670 additions and 352 deletions

View File

@@ -1,13 +1,12 @@
from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
@@ -21,11 +20,13 @@ class ChunkCacheEntry:
class ChunkCache(BasePrefixCache):
def __init__(
self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.entries: Dict[str, ChunkCacheEntry] = {}
self.reset()
@@ -51,7 +52,7 @@ class ChunkCache(BasePrefixCache):
req.req_pool_idx, :token_id_len
]
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool.free(kv_indices)
self.token_to_kv_pool_allocator.free(kv_indices)
if req.rid in self.entries:
del self.entries[req.rid]
@@ -91,3 +92,6 @@ class ChunkCache(BasePrefixCache):
def protected_size(self):
return 0
def pretty_print(self):
return ""

View File

@@ -7,8 +7,8 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
BaseTokenToKVPool,
MLATokenToKVPoolHost,
MHATokenToKVPool,
MHATokenToKVPoolHost,
ReqToTokenPool,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool: MHATokenToKVPool,
):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
self.cache_controller = HiCacheController(
token_to_kv_pool, self.token_to_kv_pool_host
)

View File

@@ -20,9 +20,12 @@ Memory pool.
SGLang has two levels of memory pool.
ReqToTokenPool maps a request to its token locations.
BaseTokenToKVPool maps a token location to its KV cache data.
TokenToKVPoolAllocator maps a token location to its KV cache data.
KVCache actually holds the physical kv cache. Allocation indices are allocated
by TokenToKVPoolAllocator.
"""
import abc
import logging
import threading
from enum import IntEnum
@@ -89,7 +92,7 @@ class ReqToTokenPool:
self.free_slots = list(range(self.size))
class BaseTokenToKVPool:
class TokenToKVPoolAllocator:
"""A memory pool that maps a token location to its kv cache data."""
def __init__(
@@ -100,11 +103,6 @@ class BaseTokenToKVPool:
):
self.size = size
self.dtype = dtype
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.device = device
self.free_slots = None
@@ -148,15 +146,22 @@ class BaseTokenToKVPool:
self.is_in_free_group = False
self.free_group = []
class KVCache(abc.ABC):
@abc.abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abc.abstractmethod
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abc.abstractmethod
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError()
@abc.abstractmethod
def set_kv_buffer(
self,
layer: RadixAttention,
@@ -167,7 +172,7 @@ class BaseTokenToKVPool:
raise NotImplementedError()
class MHATokenToKVPool(BaseTokenToKVPool):
class MHATokenToKVPool(KVCache):
def __init__(
self,
@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
device: str,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)
self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_2[loc] = src_2.to(dtype).view(store_dtype)
class MLATokenToKVPool(BaseTokenToKVPool):
class MLATokenToKVPool(KVCache):
def __init__(
self,
size: int,
@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
device: str,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)
self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.kv_lora_rank = kv_lora_rank
memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
self.kv_buffer[layer_id][loc] = cache_k
class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
class DoubleSparseTokenToKVPool(KVCache):
def __init__(
self,
size: int,
@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
heavy_channel_num: int,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)
self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)
@@ -437,12 +460,12 @@ def synchronized(func):
return wrapper
class MLATokenToKVPoolHost:
class MHATokenToKVPoolHost:
def __init__(
self,
device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 4.0,
host_to_device_ratio: float = 2.0,
pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu",
):

View File

@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
import torch
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
disable: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.disable = disable
self.reset()
@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache):
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_ids_len
]
self.token_to_kv_pool.free(kv_indices)
self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return
@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# Remove req slot release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)
@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(token_ids)