From 6db27f7b3b883ab47216114bd611b4f628bdfaa2 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Wed, 7 Aug 2024 13:40:07 -0700 Subject: [PATCH] misc: correct the int data type for token ids and indices (#969) --- python/sglang/srt/managers/schedule_batch.py | 2 +- python/sglang/srt/mem_cache/radix_cache.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 205f31797..bd2158789 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -780,7 +780,7 @@ def top_k_top_p_sampling_from_probs_torch( sampled_index = torch.multinomial(probs_sort, num_samples=1) except RuntimeError: batch_next_token_ids = torch.zeros( - (probs_sort.shape[0],), dtype=torch.int64, device=probs.device + (probs_sort.shape[0],), dtype=torch.int32, device=probs.device ) success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device) return batch_next_token_ids, success diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index c6fc3191b..347ae002e 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -74,7 +74,7 @@ class RadixCache(BasePrefixCache): if value: value = torch.concat(value) else: - value = torch.tensor([], dtype=torch.int64) + value = torch.tensor([], dtype=torch.int32) return value, last_node[0] def insert(self, key, value=None): @@ -102,7 +102,7 @@ class RadixCache(BasePrefixCache): if del_in_memory_pool: self.token_to_kv_pool.free(indices) else: - return torch.tensor([], dtype=torch.int64), self.root_node + return torch.tensor([], dtype=torch.int32), self.root_node # Radix Cache takes one ref in memory pool self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])