[Eagle] Refactor eagle speculative decoding (#3986)

Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
Ying Sheng
2025-03-05 08:06:07 -08:00
committed by GitHub
parent 5be8f1ed98
commit d3d4d76758
22 changed files with 670 additions and 352 deletions

View File

@@ -7,8 +7,8 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
BaseTokenToKVPool,
MLATokenToKVPoolHost,
MHATokenToKVPool,
MHATokenToKVPoolHost,
ReqToTokenPool,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool: MHATokenToKVPool,
):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
self.cache_controller = HiCacheController(
token_to_kv_pool, self.token_to_kv_pool_host
)