diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index d2010a531..24b2056e0 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -305,7 +305,7 @@ class HiRadixCache(RadixCache): if value: value = torch.cat(value) else: - value = torch.tensor([], dtype=torch.int32) + value = torch.tensor([], dtype=torch.int64) last_node_global = last_node while last_node.evicted: diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index bdc3ae844..28689268c 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -622,11 +622,10 @@ class HostKVCache(abc.ABC): self.mem_state = torch.zeros( (self.size,), dtype=torch.uint8, device=self.device ) - self.free_slots = torch.arange(self.size, dtype=torch.int32) - self.can_use_mem_size = self.size # A lock for synchronized operations on memory allocation and state transitions. self.lock = threading.RLock() + self.clear() @abc.abstractmethod def get_size_per_token(self): @@ -656,7 +655,7 @@ class HostKVCache(abc.ABC): def clear(self): self.mem_state.fill_(0) self.can_use_mem_size = self.size - self.free_slots = torch.arange(self.size, dtype=torch.int32) + self.free_slots = torch.arange(self.size, dtype=torch.int64) @synchronized def get_state(self, indices: torch.Tensor) -> MemoryStateInt: diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 58ee432b9..dd608c154 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -140,7 +140,7 @@ class RadixCache(BasePrefixCache): return ( torch.empty( (0,), - dtype=torch.int32, + dtype=torch.int64, device=self.device, ), self.root_node, @@ -154,7 +154,7 @@ class RadixCache(BasePrefixCache): if value: value = torch.cat(value) else: - value = torch.empty((0,), dtype=torch.int32, device=self.device) + value = torch.empty((0,), dtype=torch.int64, device=self.device) return value, last_node def insert(self, key: List, value=None):