Fix a bug with logprob streaming + chunked prefill (#2403)

This commit is contained in:
Lianmin Zheng
2024-12-08 03:55:27 -08:00
committed by GitHub
parent 61dec545b0
commit a2486eb58f
3 changed files with 24 additions and 13 deletions

View File

@@ -440,16 +440,11 @@ class Scheduler:
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
recv_reqs = []
if self.last_batch is None:
recv_req = self.recv_from_tokenizer.recv_pyobj()
recv_reqs.append(recv_req)
else:
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
else:
recv_reqs = None
@@ -949,6 +944,7 @@ class Scheduler:
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
skip_stream_req = None
if self.is_generation:
logits_output, next_token_ids, bid = result
@@ -1005,6 +1001,10 @@ class Scheduler:
else:
# being chunked reqs' prefill is not finished
req.is_being_chunked -= 1
# There is only at most one request being currently chunked.
# Because this request does not finish prefill,
# we don't want to stream the request currently being chunked.
skip_stream_req = req
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
@@ -1034,7 +1034,7 @@ class Scheduler:
# being chunked reqs' prefill is not finished
req.is_being_chunked -= 1
self.stream_output(batch.reqs)
self.stream_output(batch.reqs, skip_stream_req)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids, bid = result
@@ -1179,7 +1179,7 @@ class Scheduler:
return num_input_logprobs
def stream_output(self, reqs: List[Req]):
def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
"""Stream the output to detokenizer."""
output_rids = []
output_meta_info: List[dict] = []
@@ -1199,6 +1199,9 @@ class Scheduler:
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
for req in reqs:
if req is skip_req:
continue
# TODO(lianmin): revisit this for overlap + retract + stream
if req.finished() or (
req.stream and (is_stream_iter or len(req.output_ids) == 1)