diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py
index efa0cd5a4..42b4312f6 100644
--- a/python/sglang/backend/runtime_endpoint.py
+++ b/python/sglang/backend/runtime_endpoint.py
@@ -34,7 +34,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
@@ -50,7 +50,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
-        return res.status_code == 200
+        self._assert_success(res)
 
     def get_server_args(self):
         res = http_request(
@@ -58,6 +58,7 @@ class RuntimeEndpoint(BaseBackend):
             auth_token=self.auth_token,
             verify=self.verify,
         )
+        self._assert_success(res)
         return res.json()
 
     def get_chat_template(self):
@@ -71,7 +72,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def commit_lazy_operations(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -83,7 +84,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
@@ -95,7 +96,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def generate(
         self,
@@ -133,6 +134,8 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
+
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
@@ -167,7 +170,7 @@ class RuntimeEndpoint(BaseBackend):
             data["stream"] = True
 
         self._add_images(s, data)
-        response = http_request(
+        res = http_request(
             self.base_url + "/generate",
             json=data,
             stream=True,
@@ -175,10 +178,11 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
+        self._assert_success(res)
         pos = 0
 
         incomplete_text = ""
-        for chunk in response.iter_lines(decode_unicode=False):
+        for chunk in res.iter_lines(decode_unicode=False):
             chunk = chunk.decode("utf-8")
             if chunk and chunk.startswith("data:"):
                 if chunk == "data: [DONE]":
@@ -211,7 +215,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         prompt_len = res.json()["meta_info"]["prompt_tokens"]
 
         # Compute logprob
@@ -229,7 +233,7 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
         obj = res.json()
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
@@ -253,9 +257,13 @@ class RuntimeEndpoint(BaseBackend):
             api_key=self.api_key,
             verify=self.verify,
         )
-        assert res.status_code == 200
+        self._assert_success(res)
 
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
             data["image_data"] = s.images_[0][1]
+
+    def _assert_success(self, res):
+        if res.status_code != 200:
+            raise RuntimeError(res.json())
\ No newline at end of file
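The runtime_endpoint hunks above replace every bare `assert res.status_code == 200` with `_assert_success`, which raises a `RuntimeError` carrying the backend's JSON error body instead of an opaque `AssertionError`. Note that the hunk at @@ -50 used to return the status check as a boolean; after the change that method raises on failure and returns nothing. A minimal sketch of the pattern, where `FakeResponse` is an illustrative stand-in for the object sglang's `http_request` returns:

```python
# Sketch of the _assert_success pattern, assuming a response object with
# .status_code and .json() like the one http_request returns.
class FakeResponse:
    def __init__(self, status_code, payload):
        self.status_code = status_code
        self._payload = payload

    def json(self):
        return self._payload


def assert_success(res):
    if res.status_code != 200:
        # Propagate the backend's error body, not just the status code.
        raise RuntimeError(res.json())


assert_success(FakeResponse(200, {"ok": True}))  # silent on success
try:
    assert_success(FakeResponse(400, {"error": {"message": "bad request"}}))
except RuntimeError as e:
    print(e)  # {'error': {'message': 'bad request'}}
```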
data["image_data"] = s.images_[0][1] + + def _assert_success(self, res): + if res.status_code != 200: + raise RuntimeError(res.json()) \ No newline at end of file diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index cac2b714d..7c638cba0 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -191,7 +191,7 @@ class StreamExecutor: self.variable_event = {} # Dict[name: str -> event: threading.Event] self.meta_info = {} # Dict[name: str -> info: str] self.is_finished = False - self.error = None + self.error_ = None # For completion self.text_ = "" # The full text @@ -300,6 +300,10 @@ class StreamExecutor: self.sync() return self.messages_ + def error(self): + self.sync() + return self.error_ + def end(self): if self.use_thread: if self.worker.is_alive(): @@ -338,7 +342,7 @@ class StreamExecutor: if self.stream_var_event: for name in self.stream_var_event: self.stream_var_event[name].set() - self.error = error + self.error_ = error if self.stream_text_event: self.stream_text_event.set() @@ -713,7 +717,7 @@ class ProgramState: return self.stream_executor.sync() def error(self): - return self.stream_executor.error + return self.stream_executor.error() def text_iter(self, var_name: Optional[str] = None): if self.stream_executor.stream: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index a9f9ab2a1..3439747f5 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -31,12 +31,9 @@ class GenerateReqInput: def post_init(self): - if self.text is None: - assert ( - self.input_ids is not None - ), "Either text or input_ids should be provided" - else: - assert self.input_ids is None, "Either text or input_ids should be provided" + if ((self.text is None and self.input_ids is None) or + (self.text is not None and self.input_ids is not None)): + raise ValueError("Either text or input_ids should be provided.") if self.text is not None: is_single = isinstance(self.text, str) @@ -71,7 +68,8 @@ class GenerateReqInput: if self.rid is None: self.rid = [uuid.uuid4().hex for _ in range(num)] else: - assert isinstance(self.rid, list) + if not isinstance(self.rid, list): + raise ValueError("The rid should be a list.") if self.return_logprob is None: self.return_logprob = [False] * num @@ -129,6 +127,11 @@ class FlushCacheReq: pass +@dataclass +class AbortReq: + rid: str + + @dataclass class DetokenizeReqInput: input_ids: List[int] diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 660f09f3e..00564676b 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -20,6 +20,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.io_struct import ( + AbortReq, BatchTokenIDOut, FlushCacheReq, TokenizedGenerateReqInput, @@ -110,6 +111,8 @@ class ModelRpcServer: get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) ) set_random_seed(server_args.random_seed) + + # Print info logger.info( f"Rank {self.tp_rank}: " f"max_total_num_token={self.max_total_num_token}, " @@ -160,24 +163,6 @@ class ModelRpcServer: self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0) self.new_token_ratio_step = (0.0001, 0.05) # (down, up) - def flush_cache(self): - 
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 660f09f3e..00564676b 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -20,6 +20,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
@@ -110,6 +111,8 @@ class ModelRpcServer:
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
         set_random_seed(server_args.random_seed)
+
+        # Print info
         logger.info(
             f"Rank {self.tp_rank}: "
             f"max_total_num_token={self.max_total_num_token}, "
@@ -160,24 +163,6 @@
         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
 
-    def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
-            self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
-            self.req_to_token_pool.clear()
-            self.token_to_kv_pool.clear()
-            torch.cuda.empty_cache()
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
-
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
@@ -189,6 +174,8 @@
                 self.handle_generate_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
+            elif isinstance(recv_req, AbortReq):
+                self.abort_request(recv_req)
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
@@ -207,9 +194,8 @@
         new_batch = self.get_new_fill_batch()
 
         if new_batch is not None:
-            # Run new fill batch
+            # Run a new fill batch
             self.forward_fill_batch(new_batch)
-
             self.cache_filled_batch(new_batch)
 
             if not new_batch.is_empty():
@@ -225,14 +211,8 @@
                 self.num_generated_tokens += len(self.running_batch.reqs)
                 self.forward_decode_batch(self.running_batch)
 
-                if self.running_batch.is_empty():
-                    self.running_batch = None
-                    break
-
-                if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                    break
-
-                if self.running_batch is not None and self.tp_rank == 0:
+                # Print stats
+                if self.tp_rank == 0:
                     if self.decode_forward_ct % 40 == 0:
                         num_used = self.max_total_num_token - (
                             self.token_to_kv_pool.available_size()
@@ -250,8 +230,15 @@
                             f"gen throughput (token/s): {throuhgput:.2f}, "
                             f"#queue-req: {len(self.forward_queue)}"
                         )
+
+                if self.running_batch.is_empty():
+                    self.running_batch = None
+                    break
+
+                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                    break
         else:
-            # check the available size
+            # Check the available size
             available_size = (
                 self.token_to_kv_pool.available_size()
                 + self.tree_cache.evictable_size()
@@ -295,7 +282,7 @@
                 req.sampling_params.regex
             )
 
-        # Truncate long prompts
+        # Truncate prompts that are too long
        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
@@ -311,6 +298,7 @@
         ):
             return None
 
+        # Compute matched prefix length
         for req in self.forward_queue:
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
@@ -383,6 +371,7 @@
         if len(can_run_list) == 0:
             return None
 
+        # Print stats
         if self.tp_rank == 0:
             running_req = (
                 0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -410,6 +399,7 @@
             # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
             # )
 
+        # Return the new batch
         new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
@@ -487,7 +477,7 @@
         self.handle_finished_requests(batch)
 
     def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
@@ -671,6 +661,34 @@
         else:
             batch.reqs = []
 
+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+
 
 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer
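On the router, `AbortReq` is handled by the new `abort_request`, which only removes a request that is still waiting in `forward_queue`; a request that has already been admitted to the running batch keeps decoding. A toy sketch of that queue scan, where `Req` is a stand-in for sglang's request class and the boolean return value is an addition for the demo (the patch's method returns nothing):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Req:
    rid: str


def abort_request(forward_queue: List[Req], rid: str) -> bool:
    # Find the first queued request with this rid and drop it.
    for i, req in enumerate(forward_queue):
        if req.rid == rid:
            del forward_queue[i]
            return True
    return False  # not queued: already running, finished, or unknown


queue = [Req("a"), Req("b"), Req("c")]
print(abort_request(queue, "b"))   # True
print([r.rid for r in queue])      # ['a', 'c']
```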
" # ) + # Return the new batch new_batch = Batch.init_new( can_run_list, self.req_to_token_pool, @@ -487,7 +477,7 @@ class ModelRpcServer: self.handle_finished_requests(batch) def cache_filled_batch(self, batch: Batch): - req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist() + req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() for i, req in enumerate(batch.reqs): new_prefix_indices, new_last_node = self.tree_cache.cache_req( token_ids=tuple(req.input_ids + req.output_ids)[:-1], @@ -671,6 +661,34 @@ class ModelRpcServer: else: batch.reqs = [] + def flush_cache(self): + if len(self.forward_queue) == 0 and ( + self.running_batch is None or len(self.running_batch.reqs) == 0 + ): + self.tree_cache.reset() + self.tree_cache_metrics = {"total": 0, "hit": 0} + self.regex_fsm_cache.reset() + self.req_to_token_pool.clear() + self.token_to_kv_pool.clear() + torch.cuda.empty_cache() + logger.info("Cache flushed successfully!") + else: + warnings.warn( + f"Cache not flushed because there are pending requests. " + f"#queue-req: {len(self.forward_queue)}, " + f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" + ) + + def abort_request(self, recv_req): + to_del = None + for i, req in enumerate(self.forward_queue): + if req.rid == recv_req.rid: + to_del = i + break + + if to_del is not None: + del self.forward_queue[to_del] + class ModelRpcService(rpyc.Service): exposed_ModelRpcServer = ModelRpcServer diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8cc27f849..69ed86792 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -19,6 +19,7 @@ from sglang.srt.hf_transformers_utils import ( get_tokenizer, ) from sglang.srt.managers.io_struct import ( + AbortReq, BatchStrOut, FlushCacheReq, GenerateReqInput, @@ -42,52 +43,6 @@ class ReqState: event: asyncio.Event -global global_processor - - -def init_global_processor(server_args: ServerArgs): - global global_processor - transformers.logging.set_verbosity_error() - global_processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - ) - - -def get_pixel_values( - image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None -): - try: - processor = processor or global_processor - image, image_size = load_image(image_data) - if image_size != None: - image_hash = hash(image_data) - pixel_values = processor.image_processor(image)["pixel_values"] - for _ in range(len(pixel_values)): - pixel_values[_] = pixel_values[_].astype(np.float16) - pixel_values = np.stack(pixel_values, axis=0) - return pixel_values, image_hash, image_size - else: - image_hash = hash(image_data) - if image_aspect_ratio == "pad": - image = expand2square( - image, - tuple(int(x * 255) for x in processor.image_processor.image_mean), - ) - pixel_values = processor.image_processor(image)["pixel_values"][0] - elif image_aspect_ratio == "anyres": - pixel_values = process_anyres_image( - image, processor.image_processor, image_grid_pinpoints - ) - else: - pixel_values = processor.image_processor(image)["pixel_values"][0] - pixel_values = pixel_values.astype(np.float16) - return pixel_values, image_hash, image.size - except Exception: - print("Exception in TokenizerManager:\n" + get_exception_traceback()) - - class TokenizerManager: def __init__( self, @@ -154,10 +109,11 @@ class TokenizerManager: image_data, 
diff --git a/python/sglang/srt/openai_api_adapter.py b/python/sglang/srt/openai_api_adapter.py
index 13cd4ef08..b5aae388a 100644
--- a/python/sglang/srt/openai_api_adapter.py
+++ b/python/sglang/srt/openai_api_adapter.py
@@ -335,7 +335,7 @@ def to_openai_style_logprobs(
             ret_logprobs.tokens.append(token_text)
             ret_logprobs.token_logprobs.append(logprob)
 
-            # Not Supported yet
+            # Not supported yet
             ret_logprobs.text_offset.append(-1)
 
     def append_top_logprobs(top_logprobs):
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index b416ac042..46e691283 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -10,6 +10,7 @@ import sys
 import threading
 import time
 from typing import List, Optional, Union
+from http import HTTPStatus
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -73,7 +74,7 @@ async def get_server_args():
 
 @app.get("/flush_cache")
 async def flush_cache():
-    await tokenizer_manager.flush_cache()
+    tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
         "(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -81,24 +82,25 @@
     )
 
 
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
+async def generate_request(obj: GenerateReqInput, request: Request):
     if obj.stream:
-
         async def stream_results():
-            async for out in tokenizer_manager.generate_request(obj):
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"
 
         return StreamingResponse(stream_results(), media_type="text/event-stream")
-
-    try:
-        ret = await tokenizer_manager.generate_request(obj).__anext__()
-        return ret
-    except ValueError as e:
-        print(f"Error: {e}")
-        return JSONResponse({"error": str(e)}, status_code=400)
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            return JSONResponse({"error": {"message": str(e)}},
+                                status_code=HTTPStatus.BAD_REQUEST)
 
 app.post("/generate")(generate_request)
 app.put("/generate")(generate_request)
@@ -186,6 +188,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
+    # Send a warmup request
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
@@ -228,6 +231,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
 
     t = threading.Thread(target=_wait_and_warmup)
     t.start()
+
+    # Listen for requests
     try:
         uvicorn.run(
             app,
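In server.py, a `ValueError` raised mid-stream (for example the abort path above) is no longer swallowed: it is serialized as one final SSE event of the form `{"error": {"message": ...}}` before `data: [DONE]`, so clients can distinguish an aborted stream from a completed one. A sketch of that framing, independent of FastAPI:

```python
import json


# Sketch of the SSE error framing used by generate_request above: any
# ValueError from the chunk source becomes one final data: event, then [DONE].
def sse_events(chunks):
    try:
        for out in chunks:
            yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
    except ValueError as e:
        err = {"error": {"message": str(e)}}
        yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"


def failing_stream():
    yield {"text": "partial output"}
    raise ValueError("Abort request abc123")


for event in sse_events(failing_stream()):
    print(event, end="")
```

The non-streaming branch uses the same `{"error": {"message": ...}}` shape in a JSON body, now with `HTTPStatus.BAD_REQUEST` instead of a bare `400`.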
" "(When there are running or waiting requests, the operation will not be performed.)\n", @@ -81,24 +82,25 @@ async def flush_cache(): ) -async def generate_request(obj: GenerateReqInput): - obj.post_init() - +async def generate_request(obj: GenerateReqInput, request: Request): if obj.stream: - async def stream_results(): - async for out in tokenizer_manager.generate_request(obj): + try: + async for out in tokenizer_manager.generate_request(obj, request): + yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" + except ValueError as e: + out = {"error": {"message": str(e)}} yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" return StreamingResponse(stream_results(), media_type="text/event-stream") - - try: - ret = await tokenizer_manager.generate_request(obj).__anext__() - return ret - except ValueError as e: - print(f"Error: {e}") - return JSONResponse({"error": str(e)}, status_code=400) + else: + try: + ret = await tokenizer_manager.generate_request(obj, request).__anext__() + return ret + except ValueError as e: + return JSONResponse({"error": {"message": str(e)}}, + status_code=HTTPStatus.BAD_REQUEST) app.post("/generate")(generate_request) app.put("/generate")(generate_request) @@ -186,6 +188,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg if server_args.api_key and server_args.api_key != "": app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key) + # Send a warmup request def _wait_and_warmup(): headers = {} url = server_args.url() @@ -228,6 +231,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg t = threading.Thread(target=_wait_and_warmup) t.start() + + # Listen for requests try: uvicorn.run( app, diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 1b0d8fe6e..b2aaeafaa 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -9,7 +9,7 @@ import requests from sglang.backend.openai import OpenAI from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.global_config import global_config -from sglang.srt.utils import get_exception_traceback +from sglang.utils import get_exception_traceback def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 365ec16f4..d1fa241e9 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -93,8 +93,12 @@ def http_request( data = None else: data = bytes(dumps(json), encoding="utf-8") - resp = urllib.request.urlopen(req, data=data, cafile=verify) - return HttpResponse(resp) + + try: + resp = urllib.request.urlopen(req, data=data, cafile=verify) + return HttpResponse(resp) + except urllib.error.HTTPError as e: + return HttpResponse(e) def encode_image_base64(image_path):