From f5bbf6037dd87e4e6d5913a7440ec229f11cf39e Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Sun, 16 Mar 2025 18:14:27 -0700 Subject: [PATCH] Fix: Complete int32 to int64 conversion (#4465) --- python/sglang/srt/mem_cache/hiradix_cache.py | 2 +- python/sglang/srt/mem_cache/memory_pool.py | 5 ++--- python/sglang/srt/mem_cache/radix_cache.py | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index d2010a531..24b2056e0 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -305,7 +305,7 @@ class HiRadixCache(RadixCache): if value: value = torch.cat(value) else: - value = torch.tensor([], dtype=torch.int32) + value = torch.tensor([], dtype=torch.int64) last_node_global = last_node while last_node.evicted: diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index bdc3ae844..28689268c 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -622,11 +622,10 @@ class HostKVCache(abc.ABC): self.mem_state = torch.zeros( (self.size,), dtype=torch.uint8, device=self.device ) - self.free_slots = torch.arange(self.size, dtype=torch.int32) - self.can_use_mem_size = self.size # A lock for synchronized operations on memory allocation and state transitions. self.lock = threading.RLock() + self.clear() @abc.abstractmethod def get_size_per_token(self): @@ -656,7 +655,7 @@ class HostKVCache(abc.ABC): def clear(self): self.mem_state.fill_(0) self.can_use_mem_size = self.size - self.free_slots = torch.arange(self.size, dtype=torch.int32) + self.free_slots = torch.arange(self.size, dtype=torch.int64) @synchronized def get_state(self, indices: torch.Tensor) -> MemoryStateInt: diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 58ee432b9..dd608c154 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -140,7 +140,7 @@ class RadixCache(BasePrefixCache): return ( torch.empty( (0,), - dtype=torch.int32, + dtype=torch.int64, device=self.device, ), self.root_node, @@ -154,7 +154,7 @@ class RadixCache(BasePrefixCache): if value: value = torch.cat(value) else: - value = torch.empty((0,), dtype=torch.int32, device=self.device) + value = torch.empty((0,), dtype=torch.int64, device=self.device) return value, last_node def insert(self, key: List, value=None):