diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 449417337..279dfea26 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -580,8 +580,8 @@ class StreamExecutor: def _execute_role_end(self, expr: SglRoleEnd): if ( self.cur_role == "assistant" - and self.backend.is_chat_model and self.api_num_spec_tokens is not None + and self.backend.is_chat_model ): # Execute the stored lazy generation calls self.backend.role_end_generate(self) diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index 61eb8e46d..dbe94371b 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -19,6 +19,7 @@ class FinishReason(IntEnum): EOS_TOKEN = auto() LENGTH = auto() STOP_STR = auto() + ABORT = auto() @staticmethod def to_str(reason): @@ -28,6 +29,8 @@ class FinishReason(IntEnum): return "length" elif reason == FinishReason.STOP_STR: return "stop" + elif reason == FinishReason.ABORT: + return "abort" else: return None @@ -86,6 +89,35 @@ class Req: def max_new_tokens(self): return self.sampling_params.max_new_tokens + def check_finished(self): + if self.finished: + return + + if len(self.output_ids) >= self.sampling_params.max_new_tokens: + self.finished = True + self.finish_reason = FinishReason.LENGTH + return + + if ( + self.output_ids[-1] == self.tokenizer.eos_token_id + and self.sampling_params.ignore_eos == False + ): + self.finished = True + self.finish_reason = FinishReason.EOS_TOKEN + return + + if len(self.sampling_params.stop_strs) > 0: + tail_str = self.tokenizer.decode( + self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] + ) + + for stop_str in self.sampling_params.stop_strs: + if stop_str in tail_str: + self.finished = True + self.finish_reason = FinishReason.STOP_STR + self.hit_stop_str = stop_str + return + def jump_forward_and_retokenize(self, jump_forward_str, next_state): old_output_str = self.tokenizer.decode(self.output_ids) # FIXME: This logic does not really solve the problem of determining whether @@ -132,35 +164,6 @@ class Req: # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}") # print("*" * 100) - def check_finished(self): - if self.finished: - return - - if len(self.output_ids) >= self.sampling_params.max_new_tokens: - self.finished = True - self.finish_reason = FinishReason.LENGTH - return - - if ( - self.output_ids[-1] == self.tokenizer.eos_token_id - and self.sampling_params.ignore_eos == False - ): - self.finished = True - self.finish_reason = FinishReason.EOS_TOKEN - return - - if len(self.sampling_params.stop_strs) > 0: - tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] - ) - - for stop_str in self.sampling_params.stop_strs: - if stop_str in tail_str: - self.finished = True - self.finish_reason = FinishReason.STOP_STR - self.hit_stop_str = stop_str - return - def __repr__(self): return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, " diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 7f08313a5..fcfdd0cb0 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -679,6 +679,7 @@ class ModelRpcServer: ) def abort_request(self, recv_req): + # Delete requests in the waiting queue to_del = None for i, req in enumerate(self.forward_queue): if req.rid == recv_req.rid: @@ -688,6 
+689,14 @@ class ModelRpcServer: if to_del is not None: del self.forward_queue[to_del] + # Delete requests in the running batch + if self.running_batch: + for req in self.running_batch.reqs: + if req.rid == recv_req.rid: + req.finished = True + req.finish_reason = FinishReason.ABORT + break + class ModelRpcService(rpyc.Service): exposed_ModelRpcServer = ModelRpcServer diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 074b11c00..482347153 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -11,6 +11,7 @@ import transformers import uvloop import zmq import zmq.asyncio +from fastapi import BackgroundTasks from sglang.srt.hf_transformers_utils import ( get_config, @@ -165,7 +166,7 @@ class TokenizerManager: while True: try: - await asyncio.wait_for(event.wait(), timeout=5) + await asyncio.wait_for(event.wait(), timeout=4) except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): self.abort_request(rid) @@ -243,7 +244,7 @@ class TokenizerManager: while True: try: - await asyncio.wait_for(state.event.wait(), timeout=5) + await asyncio.wait_for(state.event.wait(), timeout=4) break except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): @@ -270,10 +271,26 @@ class TokenizerManager: self.send_to_router.send_pyobj(req) def abort_request(self, rid): + if rid not in self.rid_to_state: + return del self.rid_to_state[rid] req = AbortReq(rid) self.send_to_router.send_pyobj(req) + def create_abort_task(self, obj): + # Abort the request if the client is disconnected. + async def abort_request(): + await asyncio.sleep(3) + if obj.is_single: + self.abort_request(obj.rid) + else: + for rid in obj.rids: + self.abort_request(rid) + + background_tasks = BackgroundTasks() + background_tasks.add_task(abort_request) + return background_tasks + def create_handle_loop(self): self.to_create_loop = False loop = asyncio.get_event_loop() diff --git a/python/sglang/srt/openai_api_adapter.py b/python/sglang/srt/openai_api_adapter.py index b5aae388a..1230dc07c 100644 --- a/python/sglang/srt/openai_api_adapter.py +++ b/python/sglang/srt/openai_api_adapter.py @@ -1,10 +1,12 @@ """Conversion between OpenAI APIs and native SRT APIs""" +import asyncio import json import os +from http import HTTPStatus -from fastapi import HTTPException, Request -from fastapi.responses import StreamingResponse +from fastapi import Request +from fastapi.responses import StreamingResponse, JSONResponse from sglang.srt.conversation import ( Conversation, @@ -27,14 +29,36 @@ from sglang.srt.openai_protocol import ( CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage, + ErrorResponse, LogProbs, UsageInfo, ) -from sglang.srt.utils import jsonify_pydantic_model chat_template_name = None +def create_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST): + error = ErrorResponse(message=message, + type=err_type, + code=status_code.value) + return JSONResponse(content=error.model_dump(), + status_code=error.code) + + +def create_streaming_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: + error = ErrorResponse(message=message, + type=err_type, + code=status_code.value) + json_str = json.dumps({"error": error.model_dump()}) + return json_str + + def load_chat_template_for_openai_api(chat_template_arg): 
global chat_template_name @@ -74,8 +98,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() request = CompletionRequest(**request_json) - # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid. - assert request.n == 1 + if request.n != 1: + return create_error_response("n != 1 is not supported") adapted_request = GenerateReqInput( text=request.prompt, @@ -93,79 +117,88 @@ async def v1_completions(tokenizer_manager, raw_request: Request): return_text_in_logprobs=True, stream=request.stream, ) - adapted_request.post_init() if adapted_request.stream: async def generate_stream_resp(): stream_buffer = "" n_prev_token = 0 - async for content in tokenizer_manager.generate_request(adapted_request): - text = content["text"] - prompt_tokens = content["meta_info"]["prompt_tokens"] - completion_tokens = content["meta_info"]["completion_tokens"] + try: + async for content in tokenizer_manager.generate_request( + adapted_request, raw_request): + text = content["text"] + prompt_tokens = content["meta_info"]["prompt_tokens"] + completion_tokens = content["meta_info"]["completion_tokens"] - if not stream_buffer: # The first chunk - if request.echo: - # Prepend prompt in response text. - text = request.prompt + text + if not stream_buffer: # The first chunk + if request.echo: + # Prepend prompt in response text. + text = request.prompt + text - if request.logprobs: - # The first chunk and echo is enabled. - if not stream_buffer and request.echo: - prefill_token_logprobs = content["meta_info"][ - "prefill_token_logprobs" - ] - prefill_top_logprobs = content["meta_info"][ - "prefill_top_logprobs" - ] + if request.logprobs: + # The first chunk and echo is enabled. + if not stream_buffer and request.echo: + prefill_token_logprobs = content["meta_info"][ + "prefill_token_logprobs" + ] + prefill_top_logprobs = content["meta_info"][ + "prefill_top_logprobs" + ] + else: + prefill_token_logprobs = None + prefill_top_logprobs = None + + logprobs = to_openai_style_logprobs( + prefill_token_logprobs=prefill_token_logprobs, + prefill_top_logprobs=prefill_top_logprobs, + decode_token_logprobs=content["meta_info"][ + "decode_token_logprobs" + ][n_prev_token:], + decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][ + n_prev_token: + ], + ) + + n_prev_token = len(content["meta_info"]["decode_token_logprobs"]) else: - prefill_token_logprobs = None - prefill_top_logprobs = None + logprobs = None - logprobs = to_openai_style_logprobs( - prefill_token_logprobs=prefill_token_logprobs, - prefill_top_logprobs=prefill_top_logprobs, - decode_token_logprobs=content["meta_info"][ - "decode_token_logprobs" - ][n_prev_token:], - decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][ - n_prev_token: - ], + delta = text[len(stream_buffer) :] + stream_buffer = content["text"] + choice_data = CompletionResponseStreamChoice( + index=0, + text=delta, + logprobs=logprobs, + finish_reason=content["meta_info"]["finish_reason"], ) - - n_prev_token = len(content["meta_info"]["decode_token_logprobs"]) - else: - logprobs = None - - delta = text[len(stream_buffer) :] - stream_buffer = content["text"] - choice_data = CompletionResponseStreamChoice( - index=0, - text=delta, - logprobs=logprobs, - finish_reason=content["meta_info"]["finish_reason"], - ) - chunk = CompletionStreamResponse( - id=content["meta_info"]["id"], - object="text_completion", - choices=[choice_data], - model=request.model, - usage=UsageInfo( - prompt_tokens=prompt_tokens, - 
completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - yield f"data: {jsonify_pydantic_model(chunk)}\n\n" + chunk = CompletionStreamResponse( + id=content["meta_info"]["id"], + object="text_completion", + choices=[choice_data], + model=request.model, + usage=UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + yield f"data: {chunk.model_dump_json()}\n\n" + except ValueError as e: + error = create_streaming_error_response(str(e)) + yield f"data: {error}\n\n" yield "data: [DONE]\n\n" - return StreamingResponse(generate_stream_resp(), media_type="text/event-stream") + return StreamingResponse(generate_stream_resp(), media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(adapted_request)) # Non-streaming response. - ret = await tokenizer_manager.generate_request(adapted_request).__anext__() - ret = ret[0] if isinstance(ret, list) else ret + try: + ret = await tokenizer_manager.generate_request( + adapted_request, raw_request).__anext__() + except ValueError as e: + return create_error_response(str(e)) + ret = ret[0] if isinstance(ret, list) else ret prompt_tokens = ret["meta_info"]["prompt_tokens"] completion_tokens = ret["meta_info"]["completion_tokens"] text = ret["text"] @@ -212,8 +245,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() request = ChatCompletionRequest(**request_json) - # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid. - assert request.n == 1 + if request.n != 1: + return create_error_response("n != 1 is not supported") # Prep the data needed for the underlying GenerateReqInput: # - prompt: The full prompt string. 
@@ -258,7 +291,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): }, stream=request.stream, ) - adapted_request.post_init() if adapted_request.stream: @@ -266,13 +298,29 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): is_first = True stream_buffer = "" - async for content in tokenizer_manager.generate_request(adapted_request): - if is_first: - # First chunk with role - is_first = False + try: + async for content in tokenizer_manager.generate_request(adapted_request, raw_request): + if is_first: + # First chunk with role + is_first = False + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=content["meta_info"]["finish_reason"], + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + text = content["text"] + delta = text[len(stream_buffer) :] + stream_buffer = text choice_data = ChatCompletionResponseStreamChoice( index=0, - delta=DeltaMessage(role="assistant"), + delta=DeltaMessage(content=delta), finish_reason=content["meta_info"]["finish_reason"], ) chunk = ChatCompletionStreamResponse( @@ -280,28 +328,22 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): choices=[choice_data], model=request.model, ) - yield f"data: {jsonify_pydantic_model(chunk)}\n\n" - - text = content["text"] - delta = text[len(stream_buffer) :] - stream_buffer = text - choice_data = ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(content=delta), - finish_reason=content["meta_info"]["finish_reason"], - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - choices=[choice_data], - model=request.model, - ) - yield f"data: {jsonify_pydantic_model(chunk)}\n\n" + yield f"data: {chunk.model_dump_json()}\n\n" + except ValueError as e: + error = create_streaming_error_response(str(e)) + yield f"data: {error}\n\n" yield "data: [DONE]\n\n" - return StreamingResponse(generate_stream_resp(), media_type="text/event-stream") + return StreamingResponse(generate_stream_resp(), media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(adapted_request)) # Non-streaming response. 
- ret = await tokenizer_manager.generate_request(adapted_request).__anext__() + try: + ret = await tokenizer_manager.generate_request( + adapted_request, raw_request).__anext__() + except ValueError as e: + return create_error_response(str(e)) + prompt_tokens = ret["meta_info"]["prompt_tokens"] completion_tokens = ret["meta_info"]["completion_tokens"] choice_data = ChatCompletionResponseChoice( diff --git a/python/sglang/srt/openai_protocol.py b/python/sglang/srt/openai_protocol.py index ac88b2dd5..79c69ebdb 100644 --- a/python/sglang/srt/openai_protocol.py +++ b/python/sglang/srt/openai_protocol.py @@ -7,6 +7,14 @@ from pydantic import BaseModel, Field from typing_extensions import Literal +class ErrorResponse(BaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + class LogProbs(BaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 54bbbfc3d..a94359707 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -93,7 +93,8 @@ async def generate_request(obj: GenerateReqInput, request: Request): yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" - return StreamingResponse(stream_results(), media_type="text/event-stream") + return StreamingResponse(stream_results(), media_type="text/event-stream", + background=tokenizer_manager.create_abort_task(obj)) else: try: ret = await tokenizer_manager.generate_request(obj, request).__anext__() diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 950915e9f..f8187ad2a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -392,14 +392,4 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware): content={"detail": "Invalid API Key"}, ) response = await call_next(request) - return response - - -# FIXME: Remove this once we drop support for pydantic 1.x -IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1 - - -def jsonify_pydantic_model(obj: BaseModel): - if IS_PYDANTIC_1: - return obj.json(ensure_ascii=False) - return obj.model_dump_json() + return response \ No newline at end of file
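For readers who want to see the abort-on-disconnect wiring in isolation, here is a minimal, self-contained sketch of the pattern this diff introduces: a `BackgroundTasks` object is attached to the `StreamingResponse`, so once the connection closes (after normal completion or a client disconnect) the task fires, waits briefly, and aborts whatever request state is still registered. The `DemoManager` class, the `/stream/{rid}` route, and the print statements are illustrative stand-ins, not sglang's actual API; in the real code the manager also polls `request.is_disconnected()` every 4 seconds while waiting for tokens and forwards an `AbortReq` to the router over zmq.

```python
import asyncio

from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


class DemoManager:
    """Illustrative stand-in for TokenizerManager (not sglang's real class)."""

    def __init__(self):
        self.rid_to_state = {}

    def abort_request(self, rid):
        # Guarded like the diff: a request that already finished has no state,
        # so a late-running background task becomes a no-op.
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        print(f"abort sent to router for {rid}")  # real code sends AbortReq over zmq

    def create_abort_task(self, rid):
        async def abort_when_done():
            # Small grace period, then clean up anything still registered.
            await asyncio.sleep(3)
            self.abort_request(rid)

        tasks = BackgroundTasks()
        tasks.add_task(abort_when_done)
        return tasks


manager = DemoManager()


@app.get("/stream/{rid}")
async def stream(rid: str):
    manager.rid_to_state[rid] = {}

    async def gen():
        for i in range(5):
            yield f"data: chunk {i}\n\n"
            await asyncio.sleep(1)
        manager.rid_to_state.pop(rid, None)  # normal completion
        yield "data: [DONE]\n\n"

    # The background task runs after the response finishes or the client drops.
    return StreamingResponse(
        gen(),
        media_type="text/event-stream",
        background=manager.create_abort_task(rid),
    )
```

On the router side, the same idea shows up in `abort_request` of `ModelRpcServer`: a waiting request is simply dropped from the queue, while a running request is marked `finished` with the new `FinishReason.ABORT` so the scheduler retires it on the next step.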
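The streaming error path added above replaces the old `assert` with a JSON error frame on the SSE stream. Below is a rough illustration of what a client receives when a `ValueError` escapes the generator, assuming the `ErrorResponse` shape from `openai_protocol.py`; the plain dict stands in for the pydantic model and the message is made up.

```python
import json
from http import HTTPStatus


def create_streaming_error_response(message, err_type="BadRequestError",
                                    status_code=HTTPStatus.BAD_REQUEST):
    # Same payload shape as the helper in openai_api_adapter.py, with a plain
    # dict standing in for the ErrorResponse pydantic model.
    error = {
        "object": "error",
        "message": message,
        "type": err_type,
        "param": None,
        "code": status_code.value,
    }
    return json.dumps({"error": error})


frame = create_streaming_error_response("n != 1 is not supported")
print(f"data: {frame}\n")
print("data: [DONE]\n")
# The client sees one error frame followed by the usual [DONE] sentinel:
# data: {"error": {"object": "error", "message": "n != 1 is not supported", ...}}
# data: [DONE]
```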