Improve error handling & abort disconnected requests (#449)
This commit is contained in:
@@ -34,7 +34,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
self._assert_success(res)
|
||||||
self.model_info = res.json()
|
self.model_info = res.json()
|
||||||
|
|
||||||
self.chat_template = get_chat_template_by_model_path(
|
self.chat_template = get_chat_template_by_model_path(
|
||||||
@@ -50,7 +50,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
return res.status_code == 200
|
self._assert_success(res)
|
||||||
|
|
||||||
def get_server_args(self):
|
def get_server_args(self):
|
||||||
res = http_request(
|
res = http_request(
|
||||||
@@ -58,6 +58,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
auth_token=self.auth_token,
|
auth_token=self.auth_token,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
|
self._assert_success(res)
|
||||||
return res.json()
|
return res.json()
|
||||||
|
|
||||||
def get_chat_template(self):
|
def get_chat_template(self):
|
||||||
@@ -71,7 +72,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
self._assert_success(res)
|
||||||
|
|
||||||
def commit_lazy_operations(self, s: StreamExecutor):
|
def commit_lazy_operations(self, s: StreamExecutor):
|
||||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||||
@@ -83,7 +84,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
self._assert_success(res)
|
||||||
|
|
||||||
def fill_image(self, s: StreamExecutor):
|
def fill_image(self, s: StreamExecutor):
|
||||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||||
@@ -95,7 +96,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
self._assert_success(res)
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
@@ -133,6 +134,8 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
|
self._assert_success(res)
|
||||||
|
|
||||||
obj = res.json()
|
obj = res.json()
|
||||||
comp = obj["text"]
|
comp = obj["text"]
|
||||||
return comp, obj["meta_info"]
|
return comp, obj["meta_info"]
|
||||||
@@ -167,7 +170,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
data["stream"] = True
|
data["stream"] = True
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
|
|
||||||
response = http_request(
|
res = http_request(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json=data,
|
json=data,
|
||||||
stream=True,
|
stream=True,
|
||||||
@@ -175,10 +178,11 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
|
self._assert_success(res)
|
||||||
pos = 0
|
pos = 0
|
||||||
|
|
||||||
incomplete_text = ""
|
incomplete_text = ""
|
||||||
for chunk in response.iter_lines(decode_unicode=False):
|
for chunk in res.iter_lines(decode_unicode=False):
|
||||||
chunk = chunk.decode("utf-8")
|
chunk = chunk.decode("utf-8")
|
||||||
if chunk and chunk.startswith("data:"):
|
if chunk and chunk.startswith("data:"):
|
||||||
if chunk == "data: [DONE]":
|
if chunk == "data: [DONE]":
|
||||||
@@ -211,7 +215,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
self._assert_success(res)
|
||||||
prompt_len = res.json()["meta_info"]["prompt_tokens"]
|
prompt_len = res.json()["meta_info"]["prompt_tokens"]
|
||||||
|
|
||||||
# Compute logprob
|
# Compute logprob
|
||||||
@@ -229,7 +233,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
self._assert_success(res)
|
||||||
obj = res.json()
|
obj = res.json()
|
||||||
normalized_prompt_logprobs = [
|
normalized_prompt_logprobs = [
|
||||||
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
||||||
@@ -253,9 +257,13 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
verify=self.verify,
|
verify=self.verify,
|
||||||
)
|
)
|
||||||
assert res.status_code == 200
|
self._assert_success(res)
|
||||||
|
|
||||||
def _add_images(self, s: StreamExecutor, data):
|
def _add_images(self, s: StreamExecutor, data):
|
||||||
if s.images_:
|
if s.images_:
|
||||||
assert len(s.images_) == 1, "Only support one image."
|
assert len(s.images_) == 1, "Only support one image."
|
||||||
data["image_data"] = s.images_[0][1]
|
data["image_data"] = s.images_[0][1]
|
||||||
|
|
||||||
|
def _assert_success(self, res):
|
||||||
|
if res.status_code != 200:
|
||||||
|
raise RuntimeError(res.json())
|
||||||
@@ -191,7 +191,7 @@ class StreamExecutor:
|
|||||||
self.variable_event = {} # Dict[name: str -> event: threading.Event]
|
self.variable_event = {} # Dict[name: str -> event: threading.Event]
|
||||||
self.meta_info = {} # Dict[name: str -> info: str]
|
self.meta_info = {} # Dict[name: str -> info: str]
|
||||||
self.is_finished = False
|
self.is_finished = False
|
||||||
self.error = None
|
self.error_ = None
|
||||||
|
|
||||||
# For completion
|
# For completion
|
||||||
self.text_ = "" # The full text
|
self.text_ = "" # The full text
|
||||||
@@ -300,6 +300,10 @@ class StreamExecutor:
|
|||||||
self.sync()
|
self.sync()
|
||||||
return self.messages_
|
return self.messages_
|
||||||
|
|
||||||
|
def error(self):
|
||||||
|
self.sync()
|
||||||
|
return self.error_
|
||||||
|
|
||||||
def end(self):
|
def end(self):
|
||||||
if self.use_thread:
|
if self.use_thread:
|
||||||
if self.worker.is_alive():
|
if self.worker.is_alive():
|
||||||
@@ -338,7 +342,7 @@ class StreamExecutor:
|
|||||||
if self.stream_var_event:
|
if self.stream_var_event:
|
||||||
for name in self.stream_var_event:
|
for name in self.stream_var_event:
|
||||||
self.stream_var_event[name].set()
|
self.stream_var_event[name].set()
|
||||||
self.error = error
|
self.error_ = error
|
||||||
|
|
||||||
if self.stream_text_event:
|
if self.stream_text_event:
|
||||||
self.stream_text_event.set()
|
self.stream_text_event.set()
|
||||||
@@ -713,7 +717,7 @@ class ProgramState:
|
|||||||
return self.stream_executor.sync()
|
return self.stream_executor.sync()
|
||||||
|
|
||||||
def error(self):
|
def error(self):
|
||||||
return self.stream_executor.error
|
return self.stream_executor.error()
|
||||||
|
|
||||||
def text_iter(self, var_name: Optional[str] = None):
|
def text_iter(self, var_name: Optional[str] = None):
|
||||||
if self.stream_executor.stream:
|
if self.stream_executor.stream:
|
||||||
|
|||||||
@@ -31,12 +31,9 @@ class GenerateReqInput:
|
|||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
|
|
||||||
if self.text is None:
|
if ((self.text is None and self.input_ids is None) or
|
||||||
assert (
|
(self.text is not None and self.input_ids is not None)):
|
||||||
self.input_ids is not None
|
raise ValueError("Either text or input_ids should be provided.")
|
||||||
), "Either text or input_ids should be provided"
|
|
||||||
else:
|
|
||||||
assert self.input_ids is None, "Either text or input_ids should be provided"
|
|
||||||
|
|
||||||
if self.text is not None:
|
if self.text is not None:
|
||||||
is_single = isinstance(self.text, str)
|
is_single = isinstance(self.text, str)
|
||||||
@@ -71,7 +68,8 @@ class GenerateReqInput:
|
|||||||
if self.rid is None:
|
if self.rid is None:
|
||||||
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
self.rid = [uuid.uuid4().hex for _ in range(num)]
|
||||||
else:
|
else:
|
||||||
assert isinstance(self.rid, list)
|
if not isinstance(self.rid, list):
|
||||||
|
raise ValueError("The rid should be a list.")
|
||||||
|
|
||||||
if self.return_logprob is None:
|
if self.return_logprob is None:
|
||||||
self.return_logprob = [False] * num
|
self.return_logprob = [False] * num
|
||||||
@@ -129,6 +127,11 @@ class FlushCacheReq:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AbortReq:
|
||||||
|
rid: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DetokenizeReqInput:
|
class DetokenizeReqInput:
|
||||||
input_ids: List[int]
|
input_ids: List[int]
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache
|
|||||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
||||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
|
AbortReq,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
@@ -110,6 +111,8 @@ class ModelRpcServer:
|
|||||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||||
)
|
)
|
||||||
set_random_seed(server_args.random_seed)
|
set_random_seed(server_args.random_seed)
|
||||||
|
|
||||||
|
# Print info
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Rank {self.tp_rank}: "
|
f"Rank {self.tp_rank}: "
|
||||||
f"max_total_num_token={self.max_total_num_token}, "
|
f"max_total_num_token={self.max_total_num_token}, "
|
||||||
@@ -160,24 +163,6 @@ class ModelRpcServer:
|
|||||||
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
|
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
|
||||||
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
|
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
|
||||||
|
|
||||||
def flush_cache(self):
|
|
||||||
if len(self.forward_queue) == 0 and (
|
|
||||||
self.running_batch is None or len(self.running_batch.reqs) == 0
|
|
||||||
):
|
|
||||||
self.tree_cache.reset()
|
|
||||||
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
|
||||||
self.regex_fsm_cache.reset()
|
|
||||||
self.req_to_token_pool.clear()
|
|
||||||
self.token_to_kv_pool.clear()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
logger.info("Cache flushed successfully!")
|
|
||||||
else:
|
|
||||||
warnings.warn(
|
|
||||||
f"Cache not flushed because there are pending requests. "
|
|
||||||
f"#queue-req: {len(self.forward_queue)}, "
|
|
||||||
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def exposed_step(self, recv_reqs):
|
def exposed_step(self, recv_reqs):
|
||||||
if self.tp_size != 1:
|
if self.tp_size != 1:
|
||||||
recv_reqs = obtain(recv_reqs)
|
recv_reqs = obtain(recv_reqs)
|
||||||
@@ -189,6 +174,8 @@ class ModelRpcServer:
|
|||||||
self.handle_generate_request(recv_req)
|
self.handle_generate_request(recv_req)
|
||||||
elif isinstance(recv_req, FlushCacheReq):
|
elif isinstance(recv_req, FlushCacheReq):
|
||||||
self.flush_cache()
|
self.flush_cache()
|
||||||
|
elif isinstance(recv_req, AbortReq):
|
||||||
|
self.abort_request(recv_req)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid request: {recv_req}")
|
raise ValueError(f"Invalid request: {recv_req}")
|
||||||
|
|
||||||
@@ -207,9 +194,8 @@ class ModelRpcServer:
|
|||||||
new_batch = self.get_new_fill_batch()
|
new_batch = self.get_new_fill_batch()
|
||||||
|
|
||||||
if new_batch is not None:
|
if new_batch is not None:
|
||||||
# Run new fill batch
|
# Run a new fill batch
|
||||||
self.forward_fill_batch(new_batch)
|
self.forward_fill_batch(new_batch)
|
||||||
|
|
||||||
self.cache_filled_batch(new_batch)
|
self.cache_filled_batch(new_batch)
|
||||||
|
|
||||||
if not new_batch.is_empty():
|
if not new_batch.is_empty():
|
||||||
@@ -225,14 +211,8 @@ class ModelRpcServer:
|
|||||||
self.num_generated_tokens += len(self.running_batch.reqs)
|
self.num_generated_tokens += len(self.running_batch.reqs)
|
||||||
self.forward_decode_batch(self.running_batch)
|
self.forward_decode_batch(self.running_batch)
|
||||||
|
|
||||||
if self.running_batch.is_empty():
|
# Print stats
|
||||||
self.running_batch = None
|
if self.tp_rank == 0:
|
||||||
break
|
|
||||||
|
|
||||||
if self.out_pyobjs and self.running_batch.reqs[0].stream:
|
|
||||||
break
|
|
||||||
|
|
||||||
if self.running_batch is not None and self.tp_rank == 0:
|
|
||||||
if self.decode_forward_ct % 40 == 0:
|
if self.decode_forward_ct % 40 == 0:
|
||||||
num_used = self.max_total_num_token - (
|
num_used = self.max_total_num_token - (
|
||||||
self.token_to_kv_pool.available_size()
|
self.token_to_kv_pool.available_size()
|
||||||
@@ -250,8 +230,15 @@ class ModelRpcServer:
|
|||||||
f"gen throughput (token/s): {throuhgput:.2f}, "
|
f"gen throughput (token/s): {throuhgput:.2f}, "
|
||||||
f"#queue-req: {len(self.forward_queue)}"
|
f"#queue-req: {len(self.forward_queue)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.running_batch.is_empty():
|
||||||
|
self.running_batch = None
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.out_pyobjs and self.running_batch.reqs[0].stream:
|
||||||
|
break
|
||||||
else:
|
else:
|
||||||
# check the available size
|
# Check the available size
|
||||||
available_size = (
|
available_size = (
|
||||||
self.token_to_kv_pool.available_size()
|
self.token_to_kv_pool.available_size()
|
||||||
+ self.tree_cache.evictable_size()
|
+ self.tree_cache.evictable_size()
|
||||||
@@ -295,7 +282,7 @@ class ModelRpcServer:
|
|||||||
req.sampling_params.regex
|
req.sampling_params.regex
|
||||||
)
|
)
|
||||||
|
|
||||||
# Truncate long prompts
|
# Truncate prompts that are too long
|
||||||
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
|
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
|
||||||
req.sampling_params.max_new_tokens = min(
|
req.sampling_params.max_new_tokens = min(
|
||||||
req.sampling_params.max_new_tokens,
|
req.sampling_params.max_new_tokens,
|
||||||
@@ -311,6 +298,7 @@ class ModelRpcServer:
|
|||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Compute matched prefix length
|
||||||
for req in self.forward_queue:
|
for req in self.forward_queue:
|
||||||
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
|
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
@@ -383,6 +371,7 @@ class ModelRpcServer:
|
|||||||
if len(can_run_list) == 0:
|
if len(can_run_list) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Print stats
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
running_req = (
|
running_req = (
|
||||||
0 if self.running_batch is None else len(self.running_batch.reqs)
|
0 if self.running_batch is None else len(self.running_batch.reqs)
|
||||||
@@ -410,6 +399,7 @@ class ModelRpcServer:
|
|||||||
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
|
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
|
||||||
# )
|
# )
|
||||||
|
|
||||||
|
# Return the new batch
|
||||||
new_batch = Batch.init_new(
|
new_batch = Batch.init_new(
|
||||||
can_run_list,
|
can_run_list,
|
||||||
self.req_to_token_pool,
|
self.req_to_token_pool,
|
||||||
@@ -487,7 +477,7 @@ class ModelRpcServer:
|
|||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
def cache_filled_batch(self, batch: Batch):
|
def cache_filled_batch(self, batch: Batch):
|
||||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
|
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, req in enumerate(batch.reqs):
|
||||||
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
|
||||||
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
|
token_ids=tuple(req.input_ids + req.output_ids)[:-1],
|
||||||
@@ -671,6 +661,34 @@ class ModelRpcServer:
|
|||||||
else:
|
else:
|
||||||
batch.reqs = []
|
batch.reqs = []
|
||||||
|
|
||||||
|
def flush_cache(self):
|
||||||
|
if len(self.forward_queue) == 0 and (
|
||||||
|
self.running_batch is None or len(self.running_batch.reqs) == 0
|
||||||
|
):
|
||||||
|
self.tree_cache.reset()
|
||||||
|
self.tree_cache_metrics = {"total": 0, "hit": 0}
|
||||||
|
self.regex_fsm_cache.reset()
|
||||||
|
self.req_to_token_pool.clear()
|
||||||
|
self.token_to_kv_pool.clear()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
logger.info("Cache flushed successfully!")
|
||||||
|
else:
|
||||||
|
warnings.warn(
|
||||||
|
f"Cache not flushed because there are pending requests. "
|
||||||
|
f"#queue-req: {len(self.forward_queue)}, "
|
||||||
|
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def abort_request(self, recv_req):
|
||||||
|
to_del = None
|
||||||
|
for i, req in enumerate(self.forward_queue):
|
||||||
|
if req.rid == recv_req.rid:
|
||||||
|
to_del = i
|
||||||
|
break
|
||||||
|
|
||||||
|
if to_del is not None:
|
||||||
|
del self.forward_queue[to_del]
|
||||||
|
|
||||||
|
|
||||||
class ModelRpcService(rpyc.Service):
|
class ModelRpcService(rpyc.Service):
|
||||||
exposed_ModelRpcServer = ModelRpcServer
|
exposed_ModelRpcServer = ModelRpcServer
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
get_tokenizer,
|
get_tokenizer,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
|
AbortReq,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
FlushCacheReq,
|
FlushCacheReq,
|
||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
@@ -42,52 +43,6 @@ class ReqState:
|
|||||||
event: asyncio.Event
|
event: asyncio.Event
|
||||||
|
|
||||||
|
|
||||||
global global_processor
|
|
||||||
|
|
||||||
|
|
||||||
def init_global_processor(server_args: ServerArgs):
|
|
||||||
global global_processor
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
global_processor = get_processor(
|
|
||||||
server_args.tokenizer_path,
|
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_pixel_values(
|
|
||||||
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
processor = processor or global_processor
|
|
||||||
image, image_size = load_image(image_data)
|
|
||||||
if image_size != None:
|
|
||||||
image_hash = hash(image_data)
|
|
||||||
pixel_values = processor.image_processor(image)["pixel_values"]
|
|
||||||
for _ in range(len(pixel_values)):
|
|
||||||
pixel_values[_] = pixel_values[_].astype(np.float16)
|
|
||||||
pixel_values = np.stack(pixel_values, axis=0)
|
|
||||||
return pixel_values, image_hash, image_size
|
|
||||||
else:
|
|
||||||
image_hash = hash(image_data)
|
|
||||||
if image_aspect_ratio == "pad":
|
|
||||||
image = expand2square(
|
|
||||||
image,
|
|
||||||
tuple(int(x * 255) for x in processor.image_processor.image_mean),
|
|
||||||
)
|
|
||||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
|
||||||
elif image_aspect_ratio == "anyres":
|
|
||||||
pixel_values = process_anyres_image(
|
|
||||||
image, processor.image_processor, image_grid_pinpoints
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
|
||||||
pixel_values = pixel_values.astype(np.float16)
|
|
||||||
return pixel_values, image_hash, image.size
|
|
||||||
except Exception:
|
|
||||||
print("Exception in TokenizerManager:\n" + get_exception_traceback())
|
|
||||||
|
|
||||||
|
|
||||||
class TokenizerManager:
|
class TokenizerManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -154,10 +109,11 @@ class TokenizerManager:
|
|||||||
image_data, aspect_ratio, grid_pinpoints, self.processor
|
image_data, aspect_ratio, grid_pinpoints, self.processor
|
||||||
)
|
)
|
||||||
|
|
||||||
async def generate_request(self, obj: GenerateReqInput):
|
async def generate_request(self, obj: GenerateReqInput, request=None):
|
||||||
if self.to_create_loop:
|
if self.to_create_loop:
|
||||||
await self.create_handle_loop()
|
self.create_handle_loop()
|
||||||
|
|
||||||
|
obj.post_init()
|
||||||
is_single = obj.is_single
|
is_single = obj.is_single
|
||||||
if is_single:
|
if is_single:
|
||||||
rid = obj.rid
|
rid = obj.rid
|
||||||
@@ -170,7 +126,7 @@ class TokenizerManager:
|
|||||||
if len(input_ids) >= self.context_len:
|
if len(input_ids) >= self.context_len:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The input ({len(input_ids)} tokens) is longer than the "
|
f"The input ({len(input_ids)} tokens) is longer than the "
|
||||||
f"model's context length ({self.context_len} tokens)"
|
f"model's context length ({self.context_len} tokens)."
|
||||||
)
|
)
|
||||||
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params)
|
sampling_params = SamplingParams(**obj.sampling_params)
|
||||||
@@ -208,7 +164,14 @@ class TokenizerManager:
|
|||||||
self.rid_to_state[rid] = state
|
self.rid_to_state[rid] = state
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
await event.wait()
|
try:
|
||||||
|
await asyncio.wait_for(event.wait(), timeout=5)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if request is not None and await request.is_disconnected():
|
||||||
|
self.abort_request(rid)
|
||||||
|
raise ValueError(f"Abort request {rid}")
|
||||||
|
continue
|
||||||
|
|
||||||
out = self.convert_logprob_style(
|
out = self.convert_logprob_style(
|
||||||
state.out_list[-1],
|
state.out_list[-1],
|
||||||
obj.return_logprob,
|
obj.return_logprob,
|
||||||
@@ -226,7 +189,8 @@ class TokenizerManager:
|
|||||||
break
|
break
|
||||||
event.clear()
|
event.clear()
|
||||||
else:
|
else:
|
||||||
assert obj.stream is False
|
if obj.stream:
|
||||||
|
raise ValueError("Do not support stream for batch mode.")
|
||||||
|
|
||||||
if obj.input_ids is None:
|
if obj.input_ids is None:
|
||||||
bs = len(obj.text)
|
bs = len(obj.text)
|
||||||
@@ -276,7 +240,18 @@ class TokenizerManager:
|
|||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
rid = obj.rid[i]
|
rid = obj.rid[i]
|
||||||
state = self.rid_to_state[rid]
|
state = self.rid_to_state[rid]
|
||||||
await state.event.wait()
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(state.event.wait(), timeout=5)
|
||||||
|
break
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if request is not None and await request.is_disconnected():
|
||||||
|
for rid in obj.rid:
|
||||||
|
self.abort_request(rid)
|
||||||
|
raise ValueError(f"Abort request {rid}")
|
||||||
|
continue
|
||||||
|
|
||||||
output_list.append(
|
output_list.append(
|
||||||
self.convert_logprob_style(
|
self.convert_logprob_style(
|
||||||
state.out_list[-1],
|
state.out_list[-1],
|
||||||
@@ -290,11 +265,16 @@ class TokenizerManager:
|
|||||||
|
|
||||||
yield output_list
|
yield output_list
|
||||||
|
|
||||||
async def flush_cache(self):
|
def flush_cache(self):
|
||||||
flush_cache_req = FlushCacheReq()
|
req = FlushCacheReq()
|
||||||
self.send_to_router.send_pyobj(flush_cache_req)
|
self.send_to_router.send_pyobj(req)
|
||||||
|
|
||||||
async def create_handle_loop(self):
|
def abort_request(self, rid):
|
||||||
|
del self.rid_to_state[rid]
|
||||||
|
req = AbortReq(rid)
|
||||||
|
self.send_to_router.send_pyobj(req)
|
||||||
|
|
||||||
|
def create_handle_loop(self):
|
||||||
self.to_create_loop = False
|
self.to_create_loop = False
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
loop.create_task(self.handle_loop())
|
loop.create_task(self.handle_loop())
|
||||||
@@ -305,17 +285,20 @@ class TokenizerManager:
|
|||||||
|
|
||||||
if isinstance(recv_obj, BatchStrOut):
|
if isinstance(recv_obj, BatchStrOut):
|
||||||
for i, rid in enumerate(recv_obj.rids):
|
for i, rid in enumerate(recv_obj.rids):
|
||||||
|
state = self.rid_to_state.get(rid, None)
|
||||||
|
if state is None:
|
||||||
|
continue
|
||||||
|
|
||||||
recv_obj.meta_info[i]["id"] = rid
|
recv_obj.meta_info[i]["id"] = rid
|
||||||
out_dict = {
|
out_dict = {
|
||||||
"text": recv_obj.output_str[i],
|
"text": recv_obj.output_str[i],
|
||||||
"meta_info": recv_obj.meta_info[i],
|
"meta_info": recv_obj.meta_info[i],
|
||||||
}
|
}
|
||||||
state = self.rid_to_state[rid]
|
|
||||||
state.out_list.append(out_dict)
|
state.out_list.append(out_dict)
|
||||||
state.finished = recv_obj.finished[i]
|
state.finished = recv_obj.finished[i]
|
||||||
state.event.set()
|
state.event.set()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid object: {recv_obj}")
|
raise ValueError(f"Invalid object: {recv_obj}.")
|
||||||
|
|
||||||
def convert_logprob_style(
|
def convert_logprob_style(
|
||||||
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
||||||
@@ -356,3 +339,50 @@ class TokenizerManager:
|
|||||||
if t:
|
if t:
|
||||||
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
|
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
|
||||||
return top_logprobs
|
return top_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
global global_processor
|
||||||
|
|
||||||
|
|
||||||
|
def init_global_processor(server_args: ServerArgs):
|
||||||
|
global global_processor
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
global_processor = get_processor(
|
||||||
|
server_args.tokenizer_path,
|
||||||
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pixel_values(
|
||||||
|
image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
processor = processor or global_processor
|
||||||
|
image, image_size = load_image(image_data)
|
||||||
|
if image_size != None:
|
||||||
|
image_hash = hash(image_data)
|
||||||
|
pixel_values = processor.image_processor(image)["pixel_values"]
|
||||||
|
for _ in range(len(pixel_values)):
|
||||||
|
pixel_values[_] = pixel_values[_].astype(np.float16)
|
||||||
|
pixel_values = np.stack(pixel_values, axis=0)
|
||||||
|
return pixel_values, image_hash, image_size
|
||||||
|
else:
|
||||||
|
image_hash = hash(image_data)
|
||||||
|
if image_aspect_ratio == "pad":
|
||||||
|
image = expand2square(
|
||||||
|
image,
|
||||||
|
tuple(int(x * 255) for x in processor.image_processor.image_mean),
|
||||||
|
)
|
||||||
|
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||||
|
elif image_aspect_ratio == "anyres":
|
||||||
|
pixel_values = process_anyres_image(
|
||||||
|
image, processor.image_processor, image_grid_pinpoints
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pixel_values = processor.image_processor(image)["pixel_values"][0]
|
||||||
|
pixel_values = pixel_values.astype(np.float16)
|
||||||
|
return pixel_values, image_hash, image.size
|
||||||
|
except Exception:
|
||||||
|
print("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||||
@@ -335,7 +335,7 @@ def to_openai_style_logprobs(
|
|||||||
ret_logprobs.tokens.append(token_text)
|
ret_logprobs.tokens.append(token_text)
|
||||||
ret_logprobs.token_logprobs.append(logprob)
|
ret_logprobs.token_logprobs.append(logprob)
|
||||||
|
|
||||||
# Not Supported yet
|
# Not supported yet
|
||||||
ret_logprobs.text_offset.append(-1)
|
ret_logprobs.text_offset.append(-1)
|
||||||
|
|
||||||
def append_top_logprobs(top_logprobs):
|
def append_top_logprobs(top_logprobs):
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
# Fix a bug of Python threading
|
# Fix a bug of Python threading
|
||||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||||
@@ -73,7 +74,7 @@ async def get_server_args():
|
|||||||
|
|
||||||
@app.get("/flush_cache")
|
@app.get("/flush_cache")
|
||||||
async def flush_cache():
|
async def flush_cache():
|
||||||
await tokenizer_manager.flush_cache()
|
tokenizer_manager.flush_cache()
|
||||||
return Response(
|
return Response(
|
||||||
content="Cache flushed.\nPlease check backend logs for more details. "
|
content="Cache flushed.\nPlease check backend logs for more details. "
|
||||||
"(When there are running or waiting requests, the operation will not be performed.)\n",
|
"(When there are running or waiting requests, the operation will not be performed.)\n",
|
||||||
@@ -81,24 +82,25 @@ async def flush_cache():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def generate_request(obj: GenerateReqInput):
|
async def generate_request(obj: GenerateReqInput, request: Request):
|
||||||
obj.post_init()
|
|
||||||
|
|
||||||
if obj.stream:
|
if obj.stream:
|
||||||
|
|
||||||
async def stream_results():
|
async def stream_results():
|
||||||
async for out in tokenizer_manager.generate_request(obj):
|
try:
|
||||||
|
async for out in tokenizer_manager.generate_request(obj, request):
|
||||||
|
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
||||||
|
except ValueError as e:
|
||||||
|
out = {"error": {"message": str(e)}}
|
||||||
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
ret = await tokenizer_manager.generate_request(obj).__anext__()
|
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
|
||||||
return ret
|
return ret
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
print(f"Error: {e}")
|
return JSONResponse({"error": {"message": str(e)}},
|
||||||
return JSONResponse({"error": str(e)}, status_code=400)
|
status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
app.post("/generate")(generate_request)
|
app.post("/generate")(generate_request)
|
||||||
app.put("/generate")(generate_request)
|
app.put("/generate")(generate_request)
|
||||||
@@ -186,6 +188,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
if server_args.api_key and server_args.api_key != "":
|
if server_args.api_key and server_args.api_key != "":
|
||||||
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
|
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
|
||||||
|
|
||||||
|
# Send a warmup request
|
||||||
def _wait_and_warmup():
|
def _wait_and_warmup():
|
||||||
headers = {}
|
headers = {}
|
||||||
url = server_args.url()
|
url = server_args.url()
|
||||||
@@ -228,6 +231,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
|
|||||||
|
|
||||||
t = threading.Thread(target=_wait_and_warmup)
|
t = threading.Thread(target=_wait_and_warmup)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
|
# Listen for requests
|
||||||
try:
|
try:
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import requests
|
|||||||
from sglang.backend.openai import OpenAI
|
from sglang.backend.openai import OpenAI
|
||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
|
||||||
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
|
def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
|
||||||
|
|||||||
@@ -93,8 +93,12 @@ def http_request(
|
|||||||
data = None
|
data = None
|
||||||
else:
|
else:
|
||||||
data = bytes(dumps(json), encoding="utf-8")
|
data = bytes(dumps(json), encoding="utf-8")
|
||||||
resp = urllib.request.urlopen(req, data=data, cafile=verify)
|
|
||||||
return HttpResponse(resp)
|
try:
|
||||||
|
resp = urllib.request.urlopen(req, data=data, cafile=verify)
|
||||||
|
return HttpResponse(resp)
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
return HttpResponse(e)
|
||||||
|
|
||||||
|
|
||||||
def encode_image_base64(image_path):
|
def encode_image_base64(image_path):
|
||||||
|
|||||||
Reference in New Issue
Block a user