Bugfix: Fix Type consistency for KV indices in SWARadixCache (#11452)

Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
hzh0425
2025-10-12 23:19:44 +08:00
committed by GitHub
parent 5a6ec8f999
commit f5b34a510c

View File

@@ -449,11 +449,13 @@ class SWARadixCache(BasePrefixCache):
if self.page_size != 1:
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
if self.is_eagle:
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
@@ -502,10 +504,12 @@ class SWARadixCache(BasePrefixCache):
if self.page_size != 1:
page_aligned_len = actual_kv_len // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
else:
page_aligned_len = actual_kv_len
page_aligned_kv_indices = kv_indices.clone()
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
page_aligned_token_len = (