Support page size > 1 + eagle (#4908)

This commit is contained in:
Lianmin Zheng
2025-03-30 00:46:23 -07:00
committed by GitHub
parent 5ec5eaf760
commit b26bc86b36
16 changed files with 374 additions and 71 deletions

View File

@@ -740,11 +740,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
return req_pool_indices
def alloc_token_slots(self, num_tokens: int):
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
if self.tree_cache is not None:
self.tree_cache.evict(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
if out_cache_loc is None:
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
@@ -758,7 +761,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.tree_cache.pretty_print()
raise RuntimeError(error_msg)
return out_cache_loc
if backup_state:
return out_cache_loc, state
else:
return out_cache_loc
def alloc_paged_token_slots_extend(
self,
@@ -766,6 +772,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
backup_state: bool = False,
):
if (
self.token_to_kv_pool_allocator.available_size()
@@ -778,6 +785,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
prefix_lens, seq_lens, last_loc, extend_num_tokens
)
@@ -791,12 +801,17 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
logger.error(error_msg)
raise RuntimeError(error_msg)
return out_cache_loc
if backup_state:
return out_cache_loc, state
else:
return out_cache_loc
def alloc_paged_token_slots_decode(
self,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
backup_state: bool = False,
):
if (
self.token_to_kv_pool_allocator.available_size()
@@ -806,8 +821,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.tree_cache.evict(
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
)
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
if out_cache_loc is None:
error_msg = (
f"Decode out of memory. Try to lower your batch size.\n"
@@ -818,7 +836,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
logger.error(error_msg)
raise RuntimeError(error_msg)
return out_cache_loc
if backup_state:
return out_cache_loc, state
else:
return out_cache_loc
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []

View File

@@ -1110,7 +1110,7 @@ class Scheduler(
)
if memory_leak:
msg = (
"KV cache pool leak detected! "
"token_to_kv_pool_allocator memory leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
@@ -1121,7 +1121,7 @@ class Scheduler(
if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
msg = (
"Memory pool leak detected!"
"req_to_token_pool memory leak detected!"
f"available_size={len(self.req_to_token_pool.free_slots)}, "
f"total_size={self.req_to_token_pool.size}\n"
)