Fix a bug with logprob streaming + chunked prefill (#2403)
This commit is contained in:
@@ -440,16 +440,11 @@ class Scheduler:
|
||||
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
||||
recv_reqs = []
|
||||
|
||||
if self.last_batch is None:
|
||||
recv_req = self.recv_from_tokenizer.recv_pyobj()
|
||||
recv_reqs.append(recv_req)
|
||||
else:
|
||||
while True:
|
||||
try:
|
||||
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
recv_reqs.append(recv_req)
|
||||
while True:
|
||||
try:
|
||||
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
else:
|
||||
recv_reqs = None
|
||||
|
||||
@@ -949,6 +944,7 @@ class Scheduler:
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
skip_stream_req = None
|
||||
|
||||
if self.is_generation:
|
||||
logits_output, next_token_ids, bid = result
|
||||
@@ -1005,6 +1001,10 @@ class Scheduler:
|
||||
else:
|
||||
# being chunked reqs' prefill is not finished
|
||||
req.is_being_chunked -= 1
|
||||
# There is only at most one request being currently chunked.
|
||||
# Because this request does not finish prefill,
|
||||
# we don't want to stream the request currently being chunked.
|
||||
skip_stream_req = req
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
@@ -1034,7 +1034,7 @@ class Scheduler:
|
||||
# being chunked reqs' prefill is not finished
|
||||
req.is_being_chunked -= 1
|
||||
|
||||
self.stream_output(batch.reqs)
|
||||
self.stream_output(batch.reqs, skip_stream_req)
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
logits_output, next_token_ids, bid = result
|
||||
@@ -1179,7 +1179,7 @@ class Scheduler:
|
||||
|
||||
return num_input_logprobs
|
||||
|
||||
def stream_output(self, reqs: List[Req]):
|
||||
def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
|
||||
"""Stream the output to detokenizer."""
|
||||
output_rids = []
|
||||
output_meta_info: List[dict] = []
|
||||
@@ -1199,6 +1199,9 @@ class Scheduler:
|
||||
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
|
||||
|
||||
for req in reqs:
|
||||
if req is skip_req:
|
||||
continue
|
||||
|
||||
# TODO(lianmin): revisit this for overlap + retract + stream
|
||||
if req.finished() or (
|
||||
req.stream and (is_stream_iter or len(req.output_ids) == 1)
|
||||
|
||||
Reference in New Issue
Block a user