Fix a bug with logprob streaming + chunked prefill (#2403)
This commit is contained in:
@@ -321,6 +321,8 @@ async def async_request_sglang_generate(
|
||||
},
|
||||
"stream": not args.disable_stream,
|
||||
"lora_path": request_func_input.lora_name,
|
||||
"return_logprob": args.return_logprob,
|
||||
"logprob_start_len": -1,
|
||||
**request_func_input.extra_request_body,
|
||||
}
|
||||
headers = {}
|
||||
@@ -911,7 +913,7 @@ async def benchmark(
|
||||
prompt=test_prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
output_len=min(test_output_len, 32),
|
||||
lora_name=lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
@@ -1413,6 +1415,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Disable ignoring EOS.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--return-logprob",
|
||||
action="store_true",
|
||||
help="Return logprob.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extra-request-body",
|
||||
metavar='{"key1": "value1", "key2": "value2"}',
|
||||
|
||||
@@ -440,16 +440,11 @@ class Scheduler:
|
||||
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
||||
recv_reqs = []
|
||||
|
||||
if self.last_batch is None:
|
||||
recv_req = self.recv_from_tokenizer.recv_pyobj()
|
||||
recv_reqs.append(recv_req)
|
||||
else:
|
||||
while True:
|
||||
try:
|
||||
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
recv_reqs.append(recv_req)
|
||||
while True:
|
||||
try:
|
||||
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
else:
|
||||
recv_reqs = None
|
||||
|
||||
@@ -949,6 +944,7 @@ class Scheduler:
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
skip_stream_req = None
|
||||
|
||||
if self.is_generation:
|
||||
logits_output, next_token_ids, bid = result
|
||||
@@ -1005,6 +1001,10 @@ class Scheduler:
|
||||
else:
|
||||
# The prefill of a request that is still being chunked is not finished yet.
|
||||
req.is_being_chunked -= 1
|
||||
# At most one request is being chunked at any given time.
|
||||
# Because this request has not finished its prefill yet,
|
||||
# we don't want to stream the request currently being chunked.
|
||||
skip_stream_req = req
|
||||
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
@@ -1034,7 +1034,7 @@ class Scheduler:
|
||||
# The prefill of a request that is still being chunked is not finished yet.
|
||||
req.is_being_chunked -= 1
|
||||
|
||||
self.stream_output(batch.reqs)
|
||||
self.stream_output(batch.reqs, skip_stream_req)
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
logits_output, next_token_ids, bid = result
|
||||
@@ -1179,7 +1179,7 @@ class Scheduler:
|
||||
|
||||
return num_input_logprobs
|
||||
|
||||
def stream_output(self, reqs: List[Req]):
|
||||
def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
|
||||
"""Stream the output to detokenizer."""
|
||||
output_rids = []
|
||||
output_meta_info: List[dict] = []
|
||||
@@ -1199,6 +1199,9 @@ class Scheduler:
|
||||
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
|
||||
|
||||
for req in reqs:
|
||||
if req is skip_req:
|
||||
continue
|
||||
|
||||
# TODO(lianmin): revisit this for overlap + retract + stream
|
||||
if req.finished() or (
|
||||
req.stream and (is_stream_iter or len(req.output_ids) == 1)
|
||||
|
||||
@@ -568,6 +568,7 @@ def run_bench_serving(
|
||||
disable_tqdm=False,
|
||||
disable_stream=disable_stream,
|
||||
disable_ignore_eos=False,
|
||||
return_logprob=False,
|
||||
lora_name=None,
|
||||
extra_request_body=None,
|
||||
profile=None,
|
||||
|
||||
Reference in New Issue
Block a user