Fix retract for page size > 1 (#4914)
This commit is contained in:
@@ -599,6 +599,7 @@ class Req:
|
||||
self.extend_logprob_start_len = 0
|
||||
self.is_chunked = 0
|
||||
self.req_pool_idx = None
|
||||
self.already_computed = 0
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
@@ -960,8 +961,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# If req.input_embeds is already a list, append its content directly
|
||||
input_embeds.extend(req.input_embeds) # Use extend to avoid nesting
|
||||
|
||||
if req.is_retracted:
|
||||
req.already_computed = 0
|
||||
req.cached_tokens += pre_len - req.already_computed
|
||||
req.already_computed = seq_len
|
||||
req.is_retracted = False
|
||||
@@ -1189,7 +1188,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
else:
|
||||
# TODO: apply more fine-grained retraction
|
||||
last_uncached_pos = len(req.prefix_indices)
|
||||
last_uncached_pos = (
|
||||
(len(req.prefix_indices) + server_args.page_size - 1)
|
||||
// server_args.page_size
|
||||
* server_args.page_size
|
||||
)
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user