diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 205f31797..bd2158789 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -780,7 +780,7 @@ def top_k_top_p_sampling_from_probs_torch( sampled_index = torch.multinomial(probs_sort, num_samples=1) except RuntimeError: batch_next_token_ids = torch.zeros( - (probs_sort.shape[0],), dtype=torch.int64, device=probs.device + (probs_sort.shape[0],), dtype=torch.int32, device=probs.device ) success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device) return batch_next_token_ids, success diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index c6fc3191b..347ae002e 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -74,7 +74,7 @@ class RadixCache(BasePrefixCache): if value: value = torch.concat(value) else: - value = torch.tensor([], dtype=torch.int64) + value = torch.tensor([], dtype=torch.int32) return value, last_node[0] def insert(self, key, value=None): @@ -102,7 +102,7 @@ class RadixCache(BasePrefixCache): if del_in_memory_pool: self.token_to_kv_pool.free(indices) else: - return torch.tensor([], dtype=torch.int64), self.root_node + return torch.tensor([], dtype=torch.int32), self.root_node # Radix Cache takes one ref in memory pool self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])