Improve error handling & abort disconnected requests (#449)
@@ -34,7 +34,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         self.model_info = res.json()

         self.chat_template = get_chat_template_by_model_path(
@@ -50,7 +50,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
-        return res.status_code == 200
+        self._assert_success(res)

     def get_server_args(self):
         res = http_request(
@@ -58,6 +58,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
+        self._assert_success(res)
         return res.json()

     def get_chat_template(self):
@@ -71,7 +72,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def commit_lazy_operations(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -83,7 +84,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -95,7 +96,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def generate(
         self,
@@ -133,6 +134,8 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
+
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
@@ -167,7 +170,7 @@ class RuntimeEndpoint(BaseBackend):
         data["stream"] = True
         self._add_images(s, data)

-        response = http_request(
+        res = http_request(
             self.base_url + "/generate",
             json=data,
             stream=True,
@@ -175,10 +178,11 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
         pos = 0

         incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False):
+        for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":
@@ -211,7 +215,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         prompt_len = res.json()["meta_info"]["prompt_tokens"]

         # Compute logprob
@@ -229,7 +233,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         obj = res.json()
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
@@ -253,9 +257,13 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)

     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
             data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
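Note on the hunks above: every bare `assert res.status_code == 200` in RuntimeEndpoint now goes through `_assert_success()`, which raises a RuntimeError carrying the server's JSON error body, so failures survive `python -O` and actually say what went wrong. A minimal sketch of how a caller might observe this; the endpoint URL and the surrounding script are illustrative, not part of the commit:

# Sketch only: the new failure mode of RuntimeEndpoint, assuming a server at a
# placeholder URL. Not part of the diff above.
from sglang.backend.runtime_endpoint import RuntimeEndpoint

try:
    backend = RuntimeEndpoint("http://127.0.0.1:30000")  # fetches model info during __init__ (first hunk)
except RuntimeError as e:
    # _assert_success() surfaces the server's JSON error body instead of a bare AssertionError
    print("backend error:", e)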
@@ -191,7 +191,7 @@ class StreamExecutor:
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
-        self.error = None
+        self.error_ = None

         # For completion
         self.text_ = ""  # The full text
@@ -300,6 +300,10 @@ class StreamExecutor:
         self.sync()
         return self.messages_

+    def error(self):
+        self.sync()
+        return self.error_
+
     def end(self):
         if self.use_thread:
             if self.worker.is_alive():
@@ -338,7 +342,7 @@ class StreamExecutor:
         if self.stream_var_event:
             for name in self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.error = error
+        self.error_ = error

         if self.stream_text_event:
             self.stream_text_event.set()
@@ -713,7 +717,7 @@ class ProgramState:
         return self.stream_executor.sync()

     def error(self):
-        return self.stream_executor.error
+        return self.stream_executor.error()

     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
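With the raw attribute renamed to `error_`, `error()` becomes an accessor that syncs the executor first, and `ProgramState.error()` forwards to it instead of reading the old attribute. A hedged sketch of the intended call pattern; the program and its arguments are placeholders:

# Sketch only: using the new error() accessor. `my_program` and its arguments
# are placeholders, not part of the diff.
state = my_program.run(question="What is 1+1?", stream=True)

for _ in state.text_iter():
    pass  # drain the stream

if state.error() is not None:  # syncs first, then returns whatever the worker recorded
    print("program failed:", state.error())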
@@ -31,12 +31,9 @@ class GenerateReqInput:

     def post_init(self):

-        if self.text is None:
-            assert (
-                self.input_ids is not None
-            ), "Either text or input_ids should be provided"
-        else:
-            assert self.input_ids is None, "Either text or input_ids should be provided"
+        if ((self.text is None and self.input_ids is None) or
+                (self.text is not None and self.input_ids is not None)):
+            raise ValueError("Either text or input_ids should be provided.")

         if self.text is not None:
             is_single = isinstance(self.text, str)
@@ -71,7 +68,8 @@ class GenerateReqInput:
         if self.rid is None:
             self.rid = [uuid.uuid4().hex for _ in range(num)]
         else:
-            assert isinstance(self.rid, list)
+            if not isinstance(self.rid, list):
+                raise ValueError("The rid should be a list.")

         if self.return_logprob is None:
             self.return_logprob = [False] * num
@@ -129,6 +127,11 @@ class FlushCacheReq:
     pass


+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
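Because `post_init()` now raises ValueError instead of asserting, malformed requests fail identically under `python -O`, and the message can be relayed to the HTTP client (the server.py hunks further down turn it into a 400 or an SSE error event). A small sketch of the new behaviour, assuming the remaining dataclass fields keep their defaults:

# Sketch only: exercises the new validation in GenerateReqInput.post_init(),
# assuming the other fields have defaults.
from sglang.srt.managers.io_struct import GenerateReqInput

req = GenerateReqInput(text="Hello", input_ids=[1, 2, 3])  # both set -> invalid
try:
    req.post_init()
except ValueError as e:
    print(e)  # "Either text or input_ids should be provided."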
@@ -20,6 +20,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
@@ -110,6 +111,8 @@ class ModelRpcServer:
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
         set_random_seed(server_args.random_seed)
+
+        # Print info
         logger.info(
             f"Rank {self.tp_rank}: "
             f"max_total_num_token={self.max_total_num_token}, "
@@ -160,24 +163,6 @@ class ModelRpcServer:
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)

-    def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
-            self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
-            self.req_to_token_pool.clear()
-            self.token_to_kv_pool.clear()
-            torch.cuda.empty_cache()
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
-
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
@@ -189,6 +174,8 @@ class ModelRpcServer:
                 self.handle_generate_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
+            elif isinstance(recv_req, AbortReq):
+                self.abort_request(recv_req)
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

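The abort path enters the router here: `exposed_step` now recognizes the new AbortReq message alongside FlushCacheReq. A hedged sketch of the sending side, mirroring the `send_to_router.send_pyobj(...)` call added to TokenizerManager further down; the ZMQ socket setup itself is an assumption, not shown in this hunk:

# Sketch only: the tokenizer-side message that triggers this branch.
# `send_to_router` is assumed to be the existing socket used for FlushCacheReq;
# see the TokenizerManager.abort_request hunk below.
from sglang.srt.managers.io_struct import AbortReq

def abort(send_to_router, rid: str):
    send_to_router.send_pyobj(AbortReq(rid))  # picked up by exposed_step -> abort_request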
@@ -207,9 +194,8 @@ class ModelRpcServer:
         new_batch = self.get_new_fill_batch()

         if new_batch is not None:
-            # Run new fill batch
+            # Run a new fill batch
             self.forward_fill_batch(new_batch)
-
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
@@ -225,14 +211,8 @@ class ModelRpcServer:
                    self.num_generated_tokens += len(self.running_batch.reqs)
                    self.forward_decode_batch(self.running_batch)

-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-                        break
-
-                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                        break
-
-                    if self.running_batch is not None and self.tp_rank == 0:
+                    # Print stats
+                    if self.tp_rank == 0:
                        if self.decode_forward_ct % 40 == 0:
                            num_used = self.max_total_num_token - (
                                self.token_to_kv_pool.available_size()
@@ -250,8 +230,15 @@ class ModelRpcServer:
                                f"gen throughput (token/s): {throuhgput:.2f}, "
                                f"#queue-req: {len(self.forward_queue)}"
                            )
+
+                    if self.running_batch.is_empty():
+                        self.running_batch = None
+                        break
+
+                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                        break
            else:
-                # check the available size
+                # Check the available size
                available_size = (
                    self.token_to_kv_pool.available_size()
                    + self.tree_cache.evictable_size()
@@ -295,7 +282,7 @@ class ModelRpcServer:
                req.sampling_params.regex
            )

-        # Truncate long prompts
+        # Truncate prompts that are too long
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
@@ -311,6 +298,7 @@ class ModelRpcServer:
         ):
             return None

+        # Compute matched prefix length
         for req in self.forward_queue:
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
@@ -383,6 +371,7 @@ class ModelRpcServer:
         if len(can_run_list) == 0:
             return None

+        # Print stats
         if self.tp_rank == 0:
             running_req = (
                 0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -410,6 +399,7 @@ class ModelRpcServer:
         # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
         # )

+        # Return the new batch
         new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
@@ -487,7 +477,7 @@ class ModelRpcServer:
         self.handle_finished_requests(batch)

     def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
@@ -671,6 +661,34 @@ class ModelRpcServer:
         else:
             batch.reqs = []

+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+

 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer
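Worth noting about `abort_request` above: it only scans `forward_queue`, so a request is dropped only while it is still waiting to be scheduled; once it has been folded into the running batch it decodes to completion. A stand-in sketch of that queue-scan semantics (plain Python objects, not the real request classes):

# Sketch only: queue-scan semantics of abort_request, using stand-in objects.
from types import SimpleNamespace

forward_queue = [SimpleNamespace(rid="a"), SimpleNamespace(rid="b")]

def abort_request(rid: str) -> bool:
    for i, req in enumerate(forward_queue):
        if req.rid == rid:
            del forward_queue[i]
            return True
    return False  # already running (or finished): nothing to remove

assert abort_request("a") is True
assert abort_request("missing") is False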
@@ -19,6 +19,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
 )
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchStrOut,
     FlushCacheReq,
     GenerateReqInput,
@@ -42,52 +43,6 @@ class ReqState:
     event: asyncio.Event


-global global_processor
-
-
-def init_global_processor(server_args: ServerArgs):
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = get_processor(
-        server_args.tokenizer_path,
-        tokenizer_mode=server_args.tokenizer_mode,
-        trust_remote_code=server_args.trust_remote_code,
-    )
-
-
-def get_pixel_values(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
-):
-    try:
-        processor = processor or global_processor
-        image, image_size = load_image(image_data)
-        if image_size != None:
-            image_hash = hash(image_data)
-            pixel_values = processor.image_processor(image)["pixel_values"]
-            for _ in range(len(pixel_values)):
-                pixel_values[_] = pixel_values[_].astype(np.float16)
-            pixel_values = np.stack(pixel_values, axis=0)
-            return pixel_values, image_hash, image_size
-        else:
-            image_hash = hash(image_data)
-            if image_aspect_ratio == "pad":
-                image = expand2square(
-                    image,
-                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
-                )
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            elif image_aspect_ratio == "anyres":
-                pixel_values = process_anyres_image(
-                    image, processor.image_processor, image_grid_pinpoints
-                )
-            else:
-                pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
-            return pixel_values, image_hash, image.size
-    except Exception:
-        print("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-
 class TokenizerManager:
     def __init__(
         self,
@@ -154,10 +109,11 @@ class TokenizerManager:
                 image_data, aspect_ratio, grid_pinpoints, self.processor
             )

-    async def generate_request(self, obj: GenerateReqInput):
+    async def generate_request(self, obj: GenerateReqInput, request=None):
         if self.to_create_loop:
-            await self.create_handle_loop()
+            self.create_handle_loop()

+        obj.post_init()
         is_single = obj.is_single
         if is_single:
             rid = obj.rid
@@ -170,7 +126,7 @@ class TokenizerManager:
             if len(input_ids) >= self.context_len:
                 raise ValueError(
                     f"The input ({len(input_ids)} tokens) is longer than the "
-                    f"model's context length ({self.context_len} tokens)"
+                    f"model's context length ({self.context_len} tokens)."
                 )

             sampling_params = SamplingParams(**obj.sampling_params)
@@ -208,7 +164,14 @@ class TokenizerManager:
             self.rid_to_state[rid] = state

             while True:
-                await event.wait()
+                try:
+                    await asyncio.wait_for(event.wait(), timeout=5)
+                except asyncio.TimeoutError:
+                    if request is not None and await request.is_disconnected():
+                        self.abort_request(rid)
+                        raise ValueError(f"Abort request {rid}")
+                    continue
+
                 out = self.convert_logprob_style(
                     state.out_list[-1],
                     obj.return_logprob,
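The loop above is the core of the disconnect handling: instead of blocking forever on `event.wait()`, the coroutine wakes every 5 seconds, asks FastAPI whether the client is still there, and aborts the request if not. A self-contained sketch of the same pattern outside sglang; the Request-like object is faked so the snippet runs standalone:

# Sketch only: the poll-with-timeout pattern used above, runnable standalone.
# FakeRequest stands in for a starlette Request; is_disconnected() is the only
# method the pattern needs.
import asyncio

class FakeRequest:
    def __init__(self):
        self.gone = False
    async def is_disconnected(self) -> bool:
        return self.gone

async def wait_for_result(event: asyncio.Event, request: FakeRequest):
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=5)
        except asyncio.TimeoutError:
            if await request.is_disconnected():
                raise ValueError("Abort request")  # mirrors the abort path above
            continue
        return "result is ready"

async def main():
    event, request = asyncio.Event(), FakeRequest()
    asyncio.get_running_loop().call_later(1, event.set)  # result arrives after 1s
    print(await wait_for_result(event, request))

asyncio.run(main())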
@@ -226,7 +189,8 @@ class TokenizerManager:
                     break
                 event.clear()
         else:
-            assert obj.stream is False
+            if obj.stream:
+                raise ValueError("Do not support stream for batch mode.")

             if obj.input_ids is None:
                 bs = len(obj.text)
@@ -276,7 +240,18 @@ class TokenizerManager:
             for i in range(bs):
                 rid = obj.rid[i]
                 state = self.rid_to_state[rid]
-                await state.event.wait()
+
+                while True:
+                    try:
+                        await asyncio.wait_for(state.event.wait(), timeout=5)
+                        break
+                    except asyncio.TimeoutError:
+                        if request is not None and await request.is_disconnected():
+                            for rid in obj.rid:
+                                self.abort_request(rid)
+                            raise ValueError(f"Abort request {rid}")
+                        continue
+
                 output_list.append(
                     self.convert_logprob_style(
                         state.out_list[-1],
@@ -290,11 +265,16 @@ class TokenizerManager:

             yield output_list

-    async def flush_cache(self):
-        flush_cache_req = FlushCacheReq()
-        self.send_to_router.send_pyobj(flush_cache_req)
+    def flush_cache(self):
+        req = FlushCacheReq()
+        self.send_to_router.send_pyobj(req)

-    async def create_handle_loop(self):
+    def abort_request(self, rid):
+        del self.rid_to_state[rid]
+        req = AbortReq(rid)
+        self.send_to_router.send_pyobj(req)
+
+    def create_handle_loop(self):
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
@@ -305,17 +285,20 @@ class TokenizerManager:

             if isinstance(recv_obj, BatchStrOut):
                 for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
                     recv_obj.meta_info[i]["id"] = rid
                     out_dict = {
                         "text": recv_obj.output_str[i],
                         "meta_info": recv_obj.meta_info[i],
                     }
-                    state = self.rid_to_state[rid]
                     state.out_list.append(out_dict)
                     state.finished = recv_obj.finished[i]
                     state.event.set()
             else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+                raise ValueError(f"Invalid object: {recv_obj}.")

     def convert_logprob_style(
         self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
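The `rid_to_state.get(rid, None)` guard above exists because of the abort path: `abort_request` deletes the state entry immediately, but the router may still emit one last BatchStrOut for that rid, which must now be ignored rather than raise a KeyError. A tiny stand-in sketch of the guard:

# Sketch only: why handle_loop now looks up state with .get() and skips None.
rid_to_state = {"live-request": object()}

for rid in ["live-request", "aborted-request"]:
    state = rid_to_state.get(rid, None)
    if state is None:
        continue  # output for an already-aborted rid is silently dropped
    print("deliver output for", rid)  # normal bookkeeping continues here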
@@ -356,3 +339,50 @@ class TokenizerManager:
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
         return top_logprobs
+
+
+
+global global_processor
+
+
+def init_global_processor(server_args: ServerArgs):
+    global global_processor
+    transformers.logging.set_verbosity_error()
+    global_processor = get_processor(
+        server_args.tokenizer_path,
+        tokenizer_mode=server_args.tokenizer_mode,
+        trust_remote_code=server_args.trust_remote_code,
+    )
+
+
+def get_pixel_values(
+    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+):
+    try:
+        processor = processor or global_processor
+        image, image_size = load_image(image_data)
+        if image_size != None:
+            image_hash = hash(image_data)
+            pixel_values = processor.image_processor(image)["pixel_values"]
+            for _ in range(len(pixel_values)):
+                pixel_values[_] = pixel_values[_].astype(np.float16)
+            pixel_values = np.stack(pixel_values, axis=0)
+            return pixel_values, image_hash, image_size
+        else:
+            image_hash = hash(image_data)
+            if image_aspect_ratio == "pad":
+                image = expand2square(
+                    image,
+                    tuple(int(x * 255) for x in processor.image_processor.image_mean),
+                )
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            elif image_aspect_ratio == "anyres":
+                pixel_values = process_anyres_image(
+                    image, processor.image_processor, image_grid_pinpoints
+                )
+            else:
+                pixel_values = processor.image_processor(image)["pixel_values"][0]
+            pixel_values = pixel_values.astype(np.float16)
+            return pixel_values, image_hash, image.size
+    except Exception:
+        print("Exception in TokenizerManager:\n" + get_exception_traceback())
@@ -335,7 +335,7 @@ def to_openai_style_logprobs(
         ret_logprobs.tokens.append(token_text)
         ret_logprobs.token_logprobs.append(logprob)

-        # Not Supported yet
+        # Not supported yet
         ret_logprobs.text_offset.append(-1)

     def append_top_logprobs(top_logprobs):
@@ -10,6 +10,7 @@ import sys
 import threading
 import time
 from typing import List, Optional, Union
+from http import HTTPStatus

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -73,7 +74,7 @@ async def get_server_args():

 @app.get("/flush_cache")
 async def flush_cache():
-    await tokenizer_manager.flush_cache()
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -81,24 +82,25 @@ async def flush_cache():
     )


-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:

         async def stream_results():
-            async for out in tokenizer_manager.generate_request(obj):
-                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"

         return StreamingResponse(stream_results(), media_type="text/event-stream")
-
-    try:
-        ret = await tokenizer_manager.generate_request(obj).__anext__()
-        return ret
-    except ValueError as e:
-        print(f"Error: {e}")
-        return JSONResponse({"error": str(e)}, status_code=400)
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse({"error": {"message": str(e)}},
+                                status_code=HTTPStatus.BAD_REQUEST)


 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
@@ -186,6 +188,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
@@ -228,6 +231,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg

     t = threading.Thread(target=_wait_and_warmup)
     t.start()
+
+    # Listen for requests
     try:
         uvicorn.run(
             app,
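On the HTTP surface, errors from `generate_request` are no longer printed server-side and mapped to a bare 400: non-streaming calls now get a structured {"error": {"message": ...}} body with HTTPStatus.BAD_REQUEST, and streaming calls receive the same object as a final SSE "data:" event before "[DONE]". A hedged client-side sketch of consuming that stream; the URL, prompt, and sampling parameters are placeholders, and the `requests` package is assumed to be available:

# Sketch only: reading the /generate SSE stream and surfacing the new error events.
import json
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 16}, "stream": True},
    stream=True,
)
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue
    if line == "data: [DONE]":
        break
    payload = json.loads(line[len("data:"):])
    if "error" in payload:  # emitted by the except branch in stream_results() above
        print("aborted:", payload["error"]["message"])
        break
    print(payload["text"], end="", flush=True)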
@@ -9,7 +9,7 @@ import requests
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
-from sglang.srt.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback


 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
@@ -93,8 +93,12 @@ def http_request(
         data = None
     else:
         data = bytes(dumps(json), encoding="utf-8")
-    resp = urllib.request.urlopen(req, data=data, cafile=verify)
-    return HttpResponse(resp)
+
+    try:
+        resp = urllib.request.urlopen(req, data=data, cafile=verify)
+        return HttpResponse(resp)
+    except urllib.error.HTTPError as e:
+        return HttpResponse(e)


 def encode_image_base64(image_path):
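Finally, `http_request` stops letting urllib turn HTTP error statuses into exceptions: an HTTPError is caught and wrapped in the same HttpResponse type, so callers such as RuntimeEndpoint._assert_success see a normal response object with a non-200 status code and a readable body. A standalone sketch of the idea; the wrapper class here is a simplified stand-in, not sglang's HttpResponse:

# Sketch only: why wrapping HTTPError keeps error handling in one place.
# SimpleHttpResponse is a simplified stand-in for sglang's HttpResponse.
import json
import urllib.error
import urllib.request

class SimpleHttpResponse:
    def __init__(self, resp):
        self.resp = resp  # works for both a response object and an HTTPError
    @property
    def status_code(self):
        return self.resp.getcode()
    def json(self):
        return json.loads(self.resp.read())

def http_get(url):
    req = urllib.request.Request(url)
    try:
        return SimpleHttpResponse(urllib.request.urlopen(req))
    except urllib.error.HTTPError as e:
        return SimpleHttpResponse(e)  # 4xx/5xx come back as a response, not an exception

res = http_get("http://127.0.0.1:30000/get_model_info")  # placeholder URL
if res.status_code != 200:
    raise RuntimeError(res.json())  # mirrors RuntimeEndpoint._assert_success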