Memory pool fix for upstream change about eagle (#4170)
This commit is contained in:
@@ -22,7 +22,10 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPoolHost,
|
||||
TokenToKVPoolAllocator,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -127,12 +130,12 @@ class HiCacheController:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mem_pool_device: MHATokenToKVPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
mem_pool_host: MHATokenToKVPoolHost,
|
||||
write_policy: str = "write_through_selective",
|
||||
):
|
||||
|
||||
self.mem_pool_device = mem_pool_device
|
||||
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
||||
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
||||
self.mem_pool_host = mem_pool_host
|
||||
self.write_policy = write_policy
|
||||
|
||||
@@ -216,7 +219,7 @@ class HiCacheController:
|
||||
"""
|
||||
Load KV caches from host memory to device memory.
|
||||
"""
|
||||
device_indices = self.mem_pool_device.alloc(len(host_indices))
|
||||
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
||||
if device_indices is None:
|
||||
return None
|
||||
self.mem_pool_host.protect_load(host_indices)
|
||||
@@ -417,7 +420,7 @@ class HiCacheController:
|
||||
self, device_indices: torch.Tensor, host_indices: torch.Tensor
|
||||
) -> int:
|
||||
if self.mem_pool_host.is_synced(host_indices):
|
||||
self.mem_pool_device.free(device_indices)
|
||||
self.mem_pool_device_allocator.free(device_indices)
|
||||
self.mem_pool_host.update_backup(host_indices)
|
||||
return len(device_indices)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user