Fix the race condition in overlap mode (#1712)
This commit is contained in:
@@ -261,12 +261,7 @@ class Scheduler:
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
)
|
||||
|
||||
def cache_finished_req(req):
|
||||
free_delta = int(self.running_batch and req in self.cur_batch.reqs)
|
||||
self.tree_cache.cache_finished_req(req, free_delta=free_delta)
|
||||
|
||||
self.cache_finished_req = cache_finished_req
|
||||
self.cache_finished_req = self.tree_cache.cache_finished_req
|
||||
else:
|
||||
self.forward_batch_generation = self.tp_worker.forward_batch_generation
|
||||
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
||||
@@ -798,7 +793,6 @@ class Scheduler:
|
||||
i, req, logprob_pt, next_token_ids, logits_output
|
||||
)
|
||||
else: # embedding or reward model
|
||||
assert batch.extend_num_tokens != 0
|
||||
embeddings, bid = result
|
||||
embeddings = embeddings.tolist()
|
||||
|
||||
@@ -838,6 +832,7 @@ class Scheduler:
|
||||
# Check finish condition
|
||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||
if self.server_args.enable_overlap_schedule and req.finished():
|
||||
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
||||
continue
|
||||
|
||||
req.completion_tokens_wo_jump_forward += 1
|
||||
|
||||
Reference in New Issue
Block a user