Simplify chunked prefill (#1667)
This commit is contained in:
@@ -456,7 +456,12 @@ class Scheduler:
|
||||
and not self.last_batch.is_empty()
|
||||
):
|
||||
if self.current_inflight_req:
|
||||
self.last_batch.filter_batch(self.current_inflight_req)
|
||||
self.last_batch.filter_batch(
|
||||
current_inflight_req=self.current_inflight_req
|
||||
)
|
||||
self.tree_cache.cache_unfinished_req(self.current_inflight_req)
|
||||
# Inflight request keeps its rid but will get a new req_pool_idx.
|
||||
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
|
||||
self.batch_is_full = False
|
||||
if not self.last_batch.is_empty():
|
||||
if self.running_batch is None:
|
||||
@@ -728,27 +733,24 @@ class Scheduler:
|
||||
# Check finish conditions
|
||||
logprob_pt = 0
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req is not self.current_inflight_req:
|
||||
if req.is_inflight_req > 0:
|
||||
req.is_inflight_req -= 1
|
||||
else:
|
||||
# Inflight reqs' prefill is not finished
|
||||
req.completion_tokens_wo_jump_forward += 1
|
||||
req.output_ids.append(next_token_ids[i])
|
||||
req.check_finished()
|
||||
|
||||
if req.finished():
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
|
||||
if req.regex_fsm is not None:
|
||||
req.regex_fsm_state = req.regex_fsm.get_next_state(
|
||||
req.regex_fsm_state, next_token_ids[i]
|
||||
)
|
||||
|
||||
if req.finished():
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
|
||||
if req.is_inflight_req > 0:
|
||||
# Inflight request would get a new req idx
|
||||
req.is_inflight_req -= 1
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
if req.return_logprob:
|
||||
logprob_pt += self.add_logprob_return_values(
|
||||
i, req, logprob_pt, next_token_ids, logits_output
|
||||
@@ -760,7 +762,9 @@ class Scheduler:
|
||||
# Check finish conditions
|
||||
for i, req in enumerate(batch.reqs):
|
||||
req.embedding = embeddings[i]
|
||||
if req is not self.current_inflight_req:
|
||||
if req.is_inflight_req > 0:
|
||||
req.is_inflight_req -= 1
|
||||
else:
|
||||
# Inflight reqs' prefill is not finished
|
||||
# dummy output token for embedding models
|
||||
req.output_ids.append(0)
|
||||
@@ -771,11 +775,6 @@ class Scheduler:
|
||||
else:
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
|
||||
if req.is_inflight_req > 0:
|
||||
# Inflight request would get a new req idx
|
||||
req.is_inflight_req -= 1
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
self.stream_output(batch)
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
|
||||
Reference in New Issue
Block a user