Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
"""DetokenizerManager is a process that detokenizes the token ids."""
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
@@ -27,11 +28,16 @@ import zmq
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BatchEmbeddingOut,
|
||||
BatchMultimodalDecodeReq,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
)
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import configure_logger, get_zmq_socket
|
||||
from sglang.srt.utils import (
|
||||
configure_logger,
|
||||
get_zmq_socket,
|
||||
kill_itself_when_parent_died,
|
||||
)
|
||||
from sglang.utils import (
|
||||
TypeBasedDispatcher,
|
||||
find_printable_text,
|
||||
@@ -86,14 +92,23 @@ class DetokenizerManager:
|
||||
)
|
||||
|
||||
self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
|
||||
self.is_dummy = server_args.load_format == "dummy"
|
||||
|
||||
self._request_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
(BatchEmbeddingOut, self.handle_batch_embedding_out),
|
||||
(BatchTokenIDOut, self.handle_batch_token_id_out),
|
||||
(BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
|
||||
]
|
||||
)
|
||||
|
||||
def event_loop(self):
|
||||
"""The event loop that handles requests"""
|
||||
while True:
|
||||
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
||||
output = self._request_dispatcher(recv_obj)
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
|
||||
def trim_matched_stop(
|
||||
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
|
||||
):
|
||||
@@ -117,14 +132,6 @@ class DetokenizerManager:
|
||||
return output[:-1]
|
||||
return output
|
||||
|
||||
def event_loop(self):
|
||||
"""The event loop that handles requests"""
|
||||
|
||||
while True:
|
||||
recv_obj = self.recv_from_scheduler.recv_pyobj()
|
||||
output = self._request_dispatcher(recv_obj)
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
|
||||
def handle_batch_embedding_out(self, recv_obj: BatchEmbeddingOut):
|
||||
# If it is embedding model, no detokenization is needed.
|
||||
return recv_obj
|
||||
@@ -173,7 +180,6 @@ class DetokenizerManager:
|
||||
|
||||
# Incremental decoding
|
||||
output_strs = []
|
||||
finished_reqs = []
|
||||
for i in range(bs):
|
||||
try:
|
||||
s = self.decode_status[recv_obj.rids[i]]
|
||||
@@ -196,8 +202,6 @@ class DetokenizerManager:
|
||||
new_text = ""
|
||||
else:
|
||||
new_text = find_printable_text(new_text)
|
||||
else:
|
||||
finished_reqs.append(recv_obj.rids[i])
|
||||
|
||||
output_strs.append(
|
||||
self.trim_matched_stop(
|
||||
@@ -207,7 +211,7 @@ class DetokenizerManager:
|
||||
)
|
||||
)
|
||||
|
||||
out = BatchStrOut(
|
||||
return BatchStrOut(
|
||||
rids=recv_obj.rids,
|
||||
finished_reasons=recv_obj.finished_reasons,
|
||||
output_strs=output_strs,
|
||||
@@ -223,14 +227,15 @@ class DetokenizerManager:
|
||||
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
||||
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
||||
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
||||
input_token_ids_logprobs_val=recv_obj.input_token_ids_logprobs_val,
|
||||
input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
|
||||
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
|
||||
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
|
||||
output_hidden_states=recv_obj.output_hidden_states,
|
||||
)
|
||||
|
||||
# remove decodestatus for completed requests
|
||||
for rid in finished_reqs:
|
||||
self.decode_status.pop(rid)
|
||||
|
||||
return out
|
||||
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LimitedCapacityDict(OrderedDict):
|
||||
@@ -250,6 +255,7 @@ def run_detokenizer_process(
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
):
|
||||
kill_itself_when_parent_died()
|
||||
setproctitle.setproctitle("sglang::detokenizer")
|
||||
configure_logger(server_args)
|
||||
parent_process = psutil.Process().parent()
|
||||
|
||||
Reference in New Issue
Block a user