Fix a bug with logprob streaming + chunked prefill (#2403)
python/sglang/bench_serving.py
@@ -321,6 +321,8 @@ async def async_request_sglang_generate(
             },
             "stream": not args.disable_stream,
             "lora_path": request_func_input.lora_name,
+            "return_logprob": args.return_logprob,
+            "logprob_start_len": -1,
             **request_func_input.extra_request_body,
         }
         headers = {}
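
For context, a minimal client-side sketch of what the two new payload fields request from a running SGLang server. The endpoint URL, prompt, and response keys below are assumptions for illustration, not taken from this commit.

    import json

    import requests

    payload = {
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0.0, "max_new_tokens": 16},
        "stream": True,
        "return_logprob": True,   # mirrors args.return_logprob above
        "logprob_start_len": -1,  # the same constant bench_serving now sends
    }

    # Assumed local endpoint; the server streams SSE-style "data: {...}" lines.
    with requests.post(
        "http://127.0.0.1:30000/generate", json=payload, stream=True
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            meta = json.loads(data).get("meta_info", {})
            # Expected to carry per-token logprobs when return_logprob is set.
            print(meta.get("output_token_logprobs"))
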
@@ -911,7 +913,7 @@ async def benchmark(
         prompt=test_prompt,
         api_url=api_url,
         prompt_len=test_prompt_len,
-        output_len=test_output_len,
+        output_len=min(test_output_len, 32),
         lora_name=lora_name,
         extra_request_body=extra_request_body,
     )
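
The warmup request issued before the measured run now generates at most 32 tokens; presumably this keeps the pre-benchmark sanity check fast even when the benchmark itself asks for long outputs.
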
@@ -1413,6 +1415,11 @@ if __name__ == "__main__":
         action="store_true",
         help="Disable ignoring EOS.",
     )
+    parser.add_argument(
+        "--return-logprob",
+        action="store_true",
+        help="Return logprob.",
+    )
     parser.add_argument(
         "--extra-request-body",
         metavar='{"key1": "value1", "key2": "value2"}',
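
With the flag wired through, a run can exercise the logprob path end to end, e.g. along the lines of "python3 -m sglang.bench_serving --backend sglang --num-prompts 10 --return-logprob" (invocation sketch; the neighboring flags are pre-existing bench_serving options, and the exact spelling is an assumption).
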
python/sglang/srt/managers/scheduler.py
@@ -440,16 +440,11 @@ class Scheduler:
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []

-            if self.last_batch is None:
-                recv_req = self.recv_from_tokenizer.recv_pyobj()
-                recv_reqs.append(recv_req)
-            else:
-                while True:
-                    try:
-                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                    except zmq.ZMQError:
-                        break
-                    recv_reqs.append(recv_req)
+            while True:
+                try:
+                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_req)
         else:
             recv_reqs = None
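
This hunk removes the special case that blocked on the first request when the scheduler had no last batch, presumably so the loop never parks itself on a blocking receive while other work (such as an in-flight chunked prefill) is pending. A self-contained sketch of the resulting non-blocking drain pattern, with all names and the socket address invented for illustration:

    import time

    import zmq

    ctx = zmq.Context()
    pull = ctx.socket(zmq.PULL)
    pull.bind("tcp://127.0.0.1:5555")  # stand-in for the tokenizer->scheduler pipe

    push = ctx.socket(zmq.PUSH)
    push.connect("tcp://127.0.0.1:5555")
    for i in range(3):
        push.send_pyobj({"rid": i})  # pretend these are tokenizer requests

    def drain(sock: zmq.Socket) -> list:
        """Collect everything queued without ever blocking the event loop."""
        msgs = []
        while True:
            try:
                msgs.append(sock.recv_pyobj(zmq.NOBLOCK))
            except zmq.ZMQError:  # raised (as zmq.Again) when the queue is empty
                break
        return msgs

    time.sleep(0.1)  # let the pushed messages arrive
    print(drain(pull))  # -> [{'rid': 0}, {'rid': 1}, {'rid': 2}]
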
@@ -949,6 +944,7 @@ class Scheduler:
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+        skip_stream_req = None

         if self.is_generation:
             logits_output, next_token_ids, bid = result
@@ -1005,6 +1001,10 @@ class Scheduler:
                 else:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1
+                    # There is only at most one request being currently chunked.
+                    # Because this request does not finish prefill,
+                    # we don't want to stream the request currently being chunked.
+                    skip_stream_req = req

             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
@@ -1034,7 +1034,7 @@ class Scheduler:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1

-        self.stream_output(batch.reqs)
+        self.stream_output(batch.reqs, skip_stream_req)

     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids, bid = result
@@ -1179,7 +1179,7 @@ class Scheduler:

         return num_input_logprobs

-    def stream_output(self, reqs: List[Req]):
+    def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
         """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info: List[dict] = []
@@ -1199,6 +1199,9 @@ class Scheduler:
         is_stream_iter = self.forward_ct_decode % self.stream_interval == 0

         for req in reqs:
+            if req is skip_req:
+                continue
+
             # TODO(lianmin): revisit this for overlap + retract + stream
             if req.finished() or (
                 req.stream and (is_stream_iter or len(req.output_ids) == 1)
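
Taken together, the scheduler hunks implement the actual fix: while a request's prefill is still being chunked, its partial output (and logprobs) must not be streamed, so process_batch_result_prefill records the at-most-one such request and stream_output skips it by object identity. A distilled, self-contained sketch of that pattern, with simplified names rather than the real scheduler types:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class FakeReq:
        rid: str
        prefill_done: bool
        output: List[str] = field(default_factory=list)

    def stream_output(reqs: List[FakeReq], skip_req: Optional[FakeReq] = None) -> List[str]:
        streamed = []
        for req in reqs:
            if req is skip_req:  # identity test: skip exactly the in-flight object
                continue
            streamed.append(req.rid)
        return streamed

    def process_prefill_batch(reqs: List[FakeReq]) -> List[str]:
        skip_stream_req = None
        for req in reqs:
            if not req.prefill_done:
                # At most one request per batch is mid-chunk, as the diff's comment notes.
                skip_stream_req = req
        return stream_output(reqs, skip_stream_req)

    reqs = [FakeReq("a", True), FakeReq("b", False)]
    print(process_prefill_batch(reqs))  # -> ['a']; "b" streams after its final chunk
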
python/sglang/test/test_utils.py
@@ -568,6 +568,7 @@ def run_bench_serving(
         disable_tqdm=False,
         disable_stream=disable_stream,
         disable_ignore_eos=False,
+        return_logprob=False,
         lora_name=None,
         extra_request_body=None,
         profile=None,