Bugfix: Fix Type consistency for KV indices in SWARadixCache (#11452)
Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
@@ -449,11 +449,13 @@ class SWARadixCache(BasePrefixCache):
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||
dtype=torch.int64, copy=True
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
else:
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.clone()
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
if self.is_eagle:
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
|
||||
@@ -502,10 +504,12 @@ class SWARadixCache(BasePrefixCache):
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||
dtype=torch.int64, copy=True
|
||||
)
|
||||
else:
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.clone()
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
|
||||
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
|
||||
page_aligned_token_len = (
|
||||
|
||||
Reference in New Issue
Block a user