Fix the race condition in overlap mode (#1712)

This commit is contained in:
Lianmin Zheng
2024-10-19 06:50:56 -07:00
committed by GitHub
parent 3db43d1b08
commit 769bf11c05
6 changed files with 21 additions and 38 deletions

View File

@@ -261,12 +261,7 @@ class Scheduler:
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
def cache_finished_req(req):
free_delta = int(self.running_batch and req in self.cur_batch.reqs)
self.tree_cache.cache_finished_req(req, free_delta=free_delta)
self.cache_finished_req = cache_finished_req
self.cache_finished_req = self.tree_cache.cache_finished_req
else:
self.forward_batch_generation = self.tp_worker.forward_batch_generation
self.resolve_next_token_ids = lambda bid, x: x.tolist()
@@ -798,7 +793,6 @@ class Scheduler:
i, req, logprob_pt, next_token_ids, logits_output
)
else: # embedding or reward model
assert batch.extend_num_tokens != 0
embeddings, bid = result
embeddings = embeddings.tolist()
@@ -838,6 +832,7 @@ class Scheduler:
# Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if self.server_args.enable_overlap_schedule and req.finished():
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
continue
req.completion_tokens_wo_jump_forward += 1