Fix retract for page size > 1 (#4914)

This commit is contained in:
Lianmin Zheng
2025-03-30 02:57:15 -07:00
committed by GitHub
parent b26bc86b36
commit 4ede6770cd
10 changed files with 68 additions and 120 deletions

View File

@@ -599,6 +599,7 @@ class Req:
self.extend_logprob_start_len = 0
self.is_chunked = 0
self.req_pool_idx = None
self.already_computed = 0
def __repr__(self):
return (
@@ -960,8 +961,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# If req.input_embeds is already a list, append its content directly
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
if req.is_retracted:
req.already_computed = 0
req.cached_tokens += pre_len - req.already_computed
req.already_computed = seq_len
req.is_retracted = False
@@ -1189,7 +1188,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.req_to_token_pool.free(req.req_pool_idx)
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
last_uncached_pos = (
(len(req.prefix_indices) + server_args.page_size - 1)
// server_args.page_size
* server_args.page_size
)
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]