Improve error handling & abort disconnected requests (#449)
This commit is contained in:
@@ -19,6 +19,7 @@ from sglang.srt.hf_transformers_utils import (
|
||||
get_tokenizer,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchStrOut,
|
||||
FlushCacheReq,
|
||||
GenerateReqInput,
|
||||
@@ -42,52 +43,6 @@ class ReqState:
|
||||
event: asyncio.Event
|
||||
|
||||
|
||||
# Shared processor name for this module: it is assigned by
# init_global_processor() and read by get_pixel_values() as the fallback
# when no explicit processor argument is passed.
# NOTE: a `global` statement at module scope has no runtime effect; the name
# only becomes bound once init_global_processor() has run.
global global_processor
|
||||
|
||||
|
||||
def init_global_processor(server_args: ServerArgs):
    """Build a processor from the server arguments and store it module-wide.

    The result of ``get_processor`` is bound to the module-level name
    ``global_processor``.

    Args:
        server_args: Supplies ``tokenizer_path``, ``tokenizer_mode`` and
            ``trust_remote_code``, which are forwarded to ``get_processor``.
    """
    global global_processor

    # Restrict transformers logging to errors before the processor is loaded.
    transformers.logging.set_verbosity_error()

    tokenizer_path = server_args.tokenizer_path
    processor_options = {
        "tokenizer_mode": server_args.tokenizer_mode,
        "trust_remote_code": server_args.trust_remote_code,
    }
    global_processor = get_processor(tokenizer_path, **processor_options)
|
||||
|
||||
|
||||
def get_pixel_values(
    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
    """Load an image and convert it to float16 pixel values.

    Args:
        image_data: Image source passed to ``load_image``. It is also passed
            to ``hash()``, so it must be hashable.
        image_aspect_ratio: ``"pad"`` squares the image with
            ``expand2square`` before processing; ``"anyres"`` delegates to
            ``process_anyres_image``; any other value runs the image
            processor directly.
        image_grid_pinpoints: Forwarded to ``process_anyres_image``; only
            used when ``image_aspect_ratio == "anyres"``.
        processor: Object exposing ``image_processor``. When falsy, the
            module-level ``global_processor`` is used instead.

    Returns:
        A ``(pixel_values, image_hash, image_size)`` tuple, or ``None`` if
        any step raised. Failures are printed with their traceback rather
        than propagated, so callers must handle a ``None`` result.
    """
    try:
        processor = processor or global_processor
        image, image_size = load_image(image_data)
        image_hash = hash(image_data)

        if image_size is not None:
            # load_image reported a size itself; the processor output is a
            # sequence of arrays that gets stacked into a single array.
            # NOTE(review): presumably the multi-frame input case - confirm
            # against load_image.
            frames = processor.image_processor(image)["pixel_values"]
            # Cast into a new list instead of writing back element by
            # element: writing float16 values back into an ndarray container
            # would silently cast them to the container's dtype again.
            pixel_values = np.stack(
                [frame.astype(np.float16) for frame in frames], axis=0
            )
            return pixel_values, image_hash, image_size

        if image_aspect_ratio == "pad":
            # Pad to a square using the processor's mean colour scaled to
            # the 0-255 range.
            image = expand2square(
                image,
                tuple(int(x * 255) for x in processor.image_processor.image_mean),
            )
            pixel_values = processor.image_processor(image)["pixel_values"][0]
        elif image_aspect_ratio == "anyres":
            pixel_values = process_anyres_image(
                image, processor.image_processor, image_grid_pinpoints
            )
        else:
            pixel_values = processor.image_processor(image)["pixel_values"][0]

        pixel_values = pixel_values.astype(np.float16)
        return pixel_values, image_hash, image.size
    except Exception:
        # Best-effort by design: report the failure and signal it with None.
        print("Exception in TokenizerManager:\n" + get_exception_traceback())
        return None
|
||||
|
||||
|
||||
class TokenizerManager:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -154,10 +109,11 @@ class TokenizerManager:
|
||||
image_data, aspect_ratio, grid_pinpoints, self.processor
|
||||
)
|
||||
|
||||
async def generate_request(self, obj: GenerateReqInput):
|
||||
async def generate_request(self, obj: GenerateReqInput, request=None):
|
||||
if self.to_create_loop:
|
||||
await self.create_handle_loop()
|
||||
self.create_handle_loop()
|
||||
|
||||
obj.post_init()
|
||||
is_single = obj.is_single
|
||||
if is_single:
|
||||
rid = obj.rid
|
||||
@@ -170,7 +126,7 @@ class TokenizerManager:
|
||||
if len(input_ids) >= self.context_len:
|
||||
raise ValueError(
|
||||
f"The input ({len(input_ids)} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens)"
|
||||
f"model's context length ({self.context_len} tokens)."
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params)
|
||||
@@ -208,7 +164,14 @@ class TokenizerManager:
|
||||
self.rid_to_state[rid] = state
|
||||
|
||||
while True:
|
||||
await event.wait()
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=5)
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
self.abort_request(rid)
|
||||
raise ValueError(f"Abort request {rid}")
|
||||
continue
|
||||
|
||||
out = self.convert_logprob_style(
|
||||
state.out_list[-1],
|
||||
obj.return_logprob,
|
||||
@@ -226,7 +189,8 @@ class TokenizerManager:
|
||||
break
|
||||
event.clear()
|
||||
else:
|
||||
assert obj.stream is False
|
||||
if obj.stream:
|
||||
raise ValueError("Do not support stream for batch mode.")
|
||||
|
||||
if obj.input_ids is None:
|
||||
bs = len(obj.text)
|
||||
@@ -276,7 +240,18 @@ class TokenizerManager:
|
||||
for i in range(bs):
|
||||
rid = obj.rid[i]
|
||||
state = self.rid_to_state[rid]
|
||||
await state.event.wait()
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(state.event.wait(), timeout=5)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
for rid in obj.rid:
|
||||
self.abort_request(rid)
|
||||
raise ValueError(f"Abort request {rid}")
|
||||
continue
|
||||
|
||||
output_list.append(
|
||||
self.convert_logprob_style(
|
||||
state.out_list[-1],
|
||||
@@ -290,11 +265,16 @@ class TokenizerManager:
|
||||
|
||||
yield output_list
|
||||
|
||||
async def flush_cache(self):
    """Send a freshly created ``FlushCacheReq`` to the router socket."""
    self.send_to_router.send_pyobj(FlushCacheReq())
|
||||
def flush_cache(self):
    """Send a freshly created ``FlushCacheReq`` to the router socket."""
    self.send_to_router.send_pyobj(FlushCacheReq())
|
||||
|
||||
async def create_handle_loop(self):
|
||||
def abort_request(self, rid):
    """Forget the local state of a request and tell the router to abort it.

    Args:
        rid: Request id used as the key in ``self.rid_to_state``.

    A ``rid`` with no entry in ``self.rid_to_state`` is ignored: nothing is
    deleted and no ``AbortReq`` is sent.
    """
    # The batch disconnect path aborts every rid of the batch, so a rid may
    # have no state left; an unguarded ``del`` would raise KeyError and hide
    # the caller's own abort error.
    if rid not in self.rid_to_state:
        return

    del self.rid_to_state[rid]
    self.send_to_router.send_pyobj(AbortReq(rid))
|
||||
|
||||
def create_handle_loop(self):
|
||||
self.to_create_loop = False
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(self.handle_loop())
|
||||
@@ -305,17 +285,20 @@ class TokenizerManager:
|
||||
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
state = self.rid_to_state.get(rid, None)
|
||||
if state is None:
|
||||
continue
|
||||
|
||||
recv_obj.meta_info[i]["id"] = rid
|
||||
out_dict = {
|
||||
"text": recv_obj.output_str[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
}
|
||||
state = self.rid_to_state[rid]
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished[i]
|
||||
state.event.set()
|
||||
else:
|
||||
raise ValueError(f"Invalid object: {recv_obj}")
|
||||
raise ValueError(f"Invalid object: {recv_obj}.")
|
||||
|
||||
def convert_logprob_style(
|
||||
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
|
||||
@@ -356,3 +339,50 @@ class TokenizerManager:
|
||||
if t:
|
||||
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
|
||||
return top_logprobs
|
||||
|
||||
|
||||
|
||||
# Shared processor name for this module: it is assigned by
# init_global_processor() and read by get_pixel_values() as the fallback
# when no explicit processor argument is passed.
# NOTE: a `global` statement at module scope has no runtime effect; the name
# only becomes bound once init_global_processor() has run.
global global_processor
|
||||
|
||||
|
||||
def init_global_processor(server_args: ServerArgs):
    """Build a processor from the server arguments and store it module-wide.

    The result of ``get_processor`` is bound to the module-level name
    ``global_processor``.

    Args:
        server_args: Supplies ``tokenizer_path``, ``tokenizer_mode`` and
            ``trust_remote_code``, which are forwarded to ``get_processor``.
    """
    global global_processor

    # Restrict transformers logging to errors before the processor is loaded.
    transformers.logging.set_verbosity_error()

    tokenizer_path = server_args.tokenizer_path
    processor_options = {
        "tokenizer_mode": server_args.tokenizer_mode,
        "trust_remote_code": server_args.trust_remote_code,
    }
    global_processor = get_processor(tokenizer_path, **processor_options)
|
||||
|
||||
|
||||
def get_pixel_values(
    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
):
    """Load an image and convert it to float16 pixel values.

    Args:
        image_data: Image source passed to ``load_image``. It is also passed
            to ``hash()``, so it must be hashable.
        image_aspect_ratio: ``"pad"`` squares the image with
            ``expand2square`` before processing; ``"anyres"`` delegates to
            ``process_anyres_image``; any other value runs the image
            processor directly.
        image_grid_pinpoints: Forwarded to ``process_anyres_image``; only
            used when ``image_aspect_ratio == "anyres"``.
        processor: Object exposing ``image_processor``. When falsy, the
            module-level ``global_processor`` is used instead.

    Returns:
        A ``(pixel_values, image_hash, image_size)`` tuple, or ``None`` if
        any step raised. Failures are printed with their traceback rather
        than propagated, so callers must handle a ``None`` result.
    """
    try:
        processor = processor or global_processor
        image, image_size = load_image(image_data)
        image_hash = hash(image_data)

        if image_size is not None:
            # load_image reported a size itself; the processor output is a
            # sequence of arrays that gets stacked into a single array.
            # NOTE(review): presumably the multi-frame input case - confirm
            # against load_image.
            frames = processor.image_processor(image)["pixel_values"]
            # Cast into a new list instead of writing back element by
            # element: writing float16 values back into an ndarray container
            # would silently cast them to the container's dtype again.
            pixel_values = np.stack(
                [frame.astype(np.float16) for frame in frames], axis=0
            )
            return pixel_values, image_hash, image_size

        if image_aspect_ratio == "pad":
            # Pad to a square using the processor's mean colour scaled to
            # the 0-255 range.
            image = expand2square(
                image,
                tuple(int(x * 255) for x in processor.image_processor.image_mean),
            )
            pixel_values = processor.image_processor(image)["pixel_values"][0]
        elif image_aspect_ratio == "anyres":
            pixel_values = process_anyres_image(
                image, processor.image_processor, image_grid_pinpoints
            )
        else:
            pixel_values = processor.image_processor(image)["pixel_values"][0]

        pixel_values = pixel_values.astype(np.float16)
        return pixel_values, image_hash, image.size
    except Exception:
        # Best-effort by design: report the failure and signal it with None.
        print("Exception in TokenizerManager:\n" + get_exception_traceback())
        return None
|
||||
Reference in New Issue
Block a user