Fix memory pool handling for the upstream EAGLE-related change (#4170)

This commit is contained in:
Zhiqiang Xie
2025-03-07 00:58:20 -08:00
committed by GitHub
parent 94a2b9d33e
commit 9376ac361d
4 changed files with 27 additions and 27 deletions

View File

@@ -22,7 +22,10 @@ from typing import List, Optional
import torch
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost,
TokenToKVPoolAllocator,
)
logger = logging.getLogger(__name__)
@@ -127,12 +130,12 @@ class HiCacheController:
def __init__(
self,
mem_pool_device: MHATokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
mem_pool_host: MHATokenToKVPoolHost,
write_policy: str = "write_through_selective",
):
self.mem_pool_device = mem_pool_device
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
self.mem_pool_host = mem_pool_host
self.write_policy = write_policy
@@ -216,7 +219,7 @@ class HiCacheController:
"""
Load KV caches from host memory to device memory.
"""
device_indices = self.mem_pool_device.alloc(len(host_indices))
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
if device_indices is None:
return None
self.mem_pool_host.protect_load(host_indices)
@@ -417,7 +420,7 @@ class HiCacheController:
self, device_indices: torch.Tensor, host_indices: torch.Tensor
) -> int:
if self.mem_pool_host.is_synced(host_indices):
self.mem_pool_device.free(device_indices)
self.mem_pool_device_allocator.free(device_indices)
self.mem_pool_host.update_backup(host_indices)
return len(device_indices)
else: