Simplify chunked prefill (#1667)
This commit is contained in:
@@ -456,7 +456,12 @@ class Scheduler:
|
|||||||
and not self.last_batch.is_empty()
|
and not self.last_batch.is_empty()
|
||||||
):
|
):
|
||||||
if self.current_inflight_req:
|
if self.current_inflight_req:
|
||||||
self.last_batch.filter_batch(self.current_inflight_req)
|
self.last_batch.filter_batch(
|
||||||
|
current_inflight_req=self.current_inflight_req
|
||||||
|
)
|
||||||
|
self.tree_cache.cache_unfinished_req(self.current_inflight_req)
|
||||||
|
# Inflight request keeps its rid but will get a new req_pool_idx.
|
||||||
|
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
|
||||||
self.batch_is_full = False
|
self.batch_is_full = False
|
||||||
if not self.last_batch.is_empty():
|
if not self.last_batch.is_empty():
|
||||||
if self.running_batch is None:
|
if self.running_batch is None:
|
||||||
@@ -728,27 +733,24 @@ class Scheduler:
|
|||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
logprob_pt = 0
|
logprob_pt = 0
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
if req is not self.current_inflight_req:
|
if req.is_inflight_req > 0:
|
||||||
|
req.is_inflight_req -= 1
|
||||||
|
else:
|
||||||
# Inflight reqs' prefill is not finished
|
# Inflight reqs' prefill is not finished
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids.append(next_token_ids[i])
|
req.output_ids.append(next_token_ids[i])
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
|
if req.finished():
|
||||||
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||||
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
if req.regex_fsm is not None:
|
if req.regex_fsm is not None:
|
||||||
req.regex_fsm_state = req.regex_fsm.get_next_state(
|
req.regex_fsm_state = req.regex_fsm.get_next_state(
|
||||||
req.regex_fsm_state, next_token_ids[i]
|
req.regex_fsm_state, next_token_ids[i]
|
||||||
)
|
)
|
||||||
|
|
||||||
if req.finished():
|
|
||||||
self.tree_cache.cache_finished_req(req)
|
|
||||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
|
||||||
|
|
||||||
if req.is_inflight_req > 0:
|
|
||||||
# Inflight request would get a new req idx
|
|
||||||
req.is_inflight_req -= 1
|
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
logprob_pt += self.add_logprob_return_values(
|
logprob_pt += self.add_logprob_return_values(
|
||||||
i, req, logprob_pt, next_token_ids, logits_output
|
i, req, logprob_pt, next_token_ids, logits_output
|
||||||
@@ -760,7 +762,9 @@ class Scheduler:
|
|||||||
# Check finish conditions
|
# Check finish conditions
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
req.embedding = embeddings[i]
|
req.embedding = embeddings[i]
|
||||||
if req is not self.current_inflight_req:
|
if req.is_inflight_req > 0:
|
||||||
|
req.is_inflight_req -= 1
|
||||||
|
else:
|
||||||
# Inflight reqs' prefill is not finished
|
# Inflight reqs' prefill is not finished
|
||||||
# dummy output token for embedding models
|
# dummy output token for embedding models
|
||||||
req.output_ids.append(0)
|
req.output_ids.append(0)
|
||||||
@@ -771,11 +775,6 @@ class Scheduler:
|
|||||||
else:
|
else:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
if req.is_inflight_req > 0:
|
|
||||||
# Inflight request would get a new req idx
|
|
||||||
req.is_inflight_req -= 1
|
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
|
||||||
|
|
||||||
self.stream_output(batch)
|
self.stream_output(batch)
|
||||||
|
|
||||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||||
|
|||||||
Reference in New Issue
Block a user