Fix mixed chunked prefill in overlap mode (#2158)
This commit is contained in:
4
.github/workflows/pr-test.yml
vendored
4
.github/workflows/pr-test.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
run: |
|
||||
cd test/srt
|
||||
python3 run_suite.py --suite minimal --range-begin 0 --range-end 6
|
||||
python3 run_suite.py --suite minimal --range-begin 0 --range-end 7
|
||||
|
||||
unit-test-backend-part-2:
|
||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||
@@ -67,7 +67,7 @@ jobs:
|
||||
timeout-minutes: 25
|
||||
run: |
|
||||
cd test/srt
|
||||
python3 run_suite.py --suite minimal --range-begin 6 --range-end 14
|
||||
python3 run_suite.py --suite minimal --range-begin 7 --range-end 14
|
||||
|
||||
unit-test-backend-part-3:
|
||||
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
|
||||
|
||||
@@ -729,10 +729,13 @@ class ScheduleBatch:
|
||||
self.input_ids = input_ids
|
||||
self.out_cache_loc = out_cache_loc
|
||||
|
||||
# For overlap scheduler, the output_ids has one step delay
|
||||
delta = 0 if self.enable_overlap else -1
|
||||
|
||||
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
|
||||
self.prefix_lens.extend(
|
||||
[
|
||||
len(r.origin_input_ids) + len(r.output_ids) - 1
|
||||
len(r.origin_input_ids) + len(r.output_ids) + delta
|
||||
for r in running_batch.reqs
|
||||
]
|
||||
)
|
||||
|
||||
@@ -848,7 +848,12 @@ class Scheduler:
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
# Mixed-style chunked prefill
|
||||
if self.is_mixed_chunk and self.running_batch is not None:
|
||||
if (
|
||||
self.is_mixed_chunk
|
||||
and self.running_batch is not None
|
||||
and not (new_batch.return_logprob or self.running_batch.return_logprob)
|
||||
):
|
||||
# TODO (lianmin): support return_logprob + mixed chunked prefill
|
||||
self.running_batch.filter_batch()
|
||||
if not self.running_batch.is_empty():
|
||||
self.running_batch.prepare_for_decode()
|
||||
@@ -979,7 +984,10 @@ class Scheduler:
|
||||
continue
|
||||
|
||||
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
||||
raise ValueError("Unhandled error!")
|
||||
# Free the one delayed token for the mixed decode batch
|
||||
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
||||
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
|
||||
continue
|
||||
|
||||
if req.is_being_chunked <= 0:
|
||||
req.completion_tokens_wo_jump_forward += 1
|
||||
@@ -992,7 +1000,6 @@ class Scheduler:
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
|
||||
if req.return_logprob:
|
||||
# TODO (lianmin): need to think the case w/ mixed chunked prefill
|
||||
logprob_pt += self.add_logprob_return_values(
|
||||
i, req, logprob_pt, next_token_ids, logits_output
|
||||
)
|
||||
|
||||
@@ -199,12 +199,6 @@ class ServerArgs:
|
||||
"Overlap schedule is disabled."
|
||||
)
|
||||
|
||||
if self.enable_mixed_chunk:
|
||||
logger.info(
|
||||
"Overlap schedule is disabled because mixed-style chunked prefill is enabled."
|
||||
)
|
||||
self.disable_overlap_schedule = True
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
# Model and port args
|
||||
|
||||
Reference in New Issue
Block a user