diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 764def85c..928b207d8 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -449,11 +449,13 @@ class SWARadixCache(BasePrefixCache): if self.page_size != 1: page_aligned_len = actual_kv_len // self.page_size * self.page_size - page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() + page_aligned_kv_indices = kv_indices[:page_aligned_len].to( + dtype=torch.int64, copy=True + ) self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) else: page_aligned_len = actual_kv_len - page_aligned_kv_indices = kv_indices.clone() + page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) if self.is_eagle: self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) @@ -502,10 +504,12 @@ class SWARadixCache(BasePrefixCache): if self.page_size != 1: page_aligned_len = actual_kv_len // self.page_size * self.page_size - page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() + page_aligned_kv_indices = kv_indices[:page_aligned_len].to( + dtype=torch.int64, copy=True + ) else: page_aligned_len = actual_kv_len - page_aligned_kv_indices = kv_indices.clone() + page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1 page_aligned_token_len = (