Support page size > 1 + eagle (#4908)
This commit is contained in:
@@ -740,11 +740,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
)
|
||||
return req_pool_indices
|
||||
|
||||
def alloc_token_slots(self, num_tokens: int):
|
||||
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
|
||||
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens)
|
||||
|
||||
if backup_state:
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
if out_cache_loc is None:
|
||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
||||
@@ -758,7 +761,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.tree_cache.pretty_print()
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
return out_cache_loc
|
||||
if backup_state:
|
||||
return out_cache_loc, state
|
||||
else:
|
||||
return out_cache_loc
|
||||
|
||||
def alloc_paged_token_slots_extend(
|
||||
self,
|
||||
@@ -766,6 +772,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
extend_num_tokens: int,
|
||||
backup_state: bool = False,
|
||||
):
|
||||
if (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
@@ -778,6 +785,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||
)
|
||||
|
||||
if backup_state:
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
||||
prefix_lens, seq_lens, last_loc, extend_num_tokens
|
||||
)
|
||||
@@ -791,12 +801,17 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return out_cache_loc
|
||||
|
||||
if backup_state:
|
||||
return out_cache_loc, state
|
||||
else:
|
||||
return out_cache_loc
|
||||
|
||||
def alloc_paged_token_slots_decode(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
backup_state: bool = False,
|
||||
):
|
||||
if (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
@@ -806,8 +821,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.tree_cache.evict(
|
||||
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
|
||||
)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
|
||||
|
||||
if backup_state:
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
|
||||
if out_cache_loc is None:
|
||||
error_msg = (
|
||||
f"Decode out of memory. Try to lower your batch size.\n"
|
||||
@@ -818,7 +836,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return out_cache_loc
|
||||
|
||||
if backup_state:
|
||||
return out_cache_loc, state
|
||||
else:
|
||||
return out_cache_loc
|
||||
|
||||
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
|
||||
self.encoder_lens_cpu = []
|
||||
|
||||
Reference in New Issue
Block a user