Support page size > 1 + eagle (#4908)

This commit is contained in:
Lianmin Zheng
2025-03-30 00:46:23 -07:00
committed by GitHub
parent 5ec5eaf760
commit b26bc86b36
16 changed files with 374 additions and 71 deletions

View File

@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
if self.free_group:
self.free(torch.cat(self.free_group))
def backup_state(self):
    """Return the allocator's current ``free_slots`` object as a snapshot.

    The value is handed back by reference (no copy is made); it is meant
    to be passed to ``restore_state`` later to roll the allocator back.
    """
    snapshot = self.free_slots
    return snapshot
def restore_state(self, free_slots):
    """Replace the allocator's ``free_slots`` with a previously saved value.

    Args:
        free_slots: The object to install as ``self.free_slots``; it is
            stored as-is, without copying.
    """
    # Plain rebind: whatever free list was in place before is dropped.
    self.free_slots = free_slots
    return None
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_slots = torch.arange(

View File

@@ -218,6 +218,9 @@ class PagedTokenToKVPoolAllocator:
next_power_of_2(extend_num_tokens),
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
merged_value = self.ret_values.item()
num_new_pages = merged_value >> 32
if num_new_pages > len(self.free_pages):
@@ -248,6 +251,9 @@ class PagedTokenToKVPoolAllocator:
self.page_size,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
num_new_pages = self.ret_values.item()
if num_new_pages > len(self.free_pages):
return None
@@ -265,6 +271,9 @@ class PagedTokenToKVPoolAllocator:
else:
self.free_group.append(free_index)
if self.debug_mode:
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
def free_group_begin(self):
    """Start a free group.

    Resets ``free_group`` to a fresh empty list and clears the
    ``is_not_in_free_group`` flag.
    """
    self.free_group = list()
    self.is_not_in_free_group = False
@@ -274,6 +283,12 @@ class PagedTokenToKVPoolAllocator:
if self.free_group:
self.free(torch.cat(self.free_group))
def backup_state(self):
    """Return the allocator's current ``free_pages`` object as a snapshot.

    The value is handed back by reference (no copy is made); it is meant
    to be passed to ``restore_state`` later to roll the allocator back.
    """
    snapshot = self.free_pages
    return snapshot
def restore_state(self, free_pages):
    """Replace the allocator's ``free_pages`` with a previously saved value.

    Args:
        free_pages: The object to install as ``self.free_pages``; it is
            stored as-is, without copying.
    """
    # Plain rebind: whatever free-page list was in place before is dropped.
    self.free_pages = free_pages
    return None
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(