Support page size > 1 + eagle (#4908)
This commit is contained in:
@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
|
||||
if self.free_group:
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
def backup_state(self):
|
||||
return self.free_slots
|
||||
|
||||
def restore_state(self, free_slots):
|
||||
self.free_slots = free_slots
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_slots = torch.arange(
|
||||
|
||||
@@ -218,6 +218,9 @@ class PagedTokenToKVPoolAllocator:
|
||||
next_power_of_2(extend_num_tokens),
|
||||
)
|
||||
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
merged_value = self.ret_values.item()
|
||||
num_new_pages = merged_value >> 32
|
||||
if num_new_pages > len(self.free_pages):
|
||||
@@ -248,6 +251,9 @@ class PagedTokenToKVPoolAllocator:
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
num_new_pages = self.ret_values.item()
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
@@ -265,6 +271,9 @@ class PagedTokenToKVPoolAllocator:
|
||||
else:
|
||||
self.free_group.append(free_index)
|
||||
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
|
||||
|
||||
def free_group_begin(self):
|
||||
self.is_not_in_free_group = False
|
||||
self.free_group = []
|
||||
@@ -274,6 +283,12 @@ class PagedTokenToKVPoolAllocator:
|
||||
if self.free_group:
|
||||
self.free(torch.cat(self.free_group))
|
||||
|
||||
def backup_state(self):
|
||||
return self.free_pages
|
||||
|
||||
def restore_state(self, free_pages):
|
||||
self.free_pages = free_pages
|
||||
|
||||
def clear(self):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.free_pages = torch.arange(
|
||||
|
||||
Reference in New Issue
Block a user