From f6ab4ca6bc45896542dfa6e364b5683385cbccd2 Mon Sep 17 00:00:00 2001 From: mlmz <54172054+minleminzui@users.noreply.github.com> Date: Sat, 22 Mar 2025 10:11:15 +0800 Subject: [PATCH] fix: fix ipython running error for Engine due to outlines nest_asyncio (#4582) Co-authored-by: shuaills --- docs/backend/offline_engine_api.ipynb | 7 +----- docs/backend/patch.py | 4 ++++ python/sglang/srt/openai_api/adapter.py | 9 +------ python/sglang/utils.py | 31 +++++++++++++++++++++++++ 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/docs/backend/offline_engine_api.ipynb b/docs/backend/offline_engine_api.ipynb index f3de53b37..22ec2e590 100644 --- a/docs/backend/offline_engine_api.ipynb +++ b/docs/backend/offline_engine_api.ipynb @@ -19,12 +19,7 @@ "- Streaming asynchronous generation\n", "\n", "Additionally, you can easily build a custom server on top of the SGLang offline engine. A detailed example working in a python script can be found in [custom_server](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/custom_server.py).\n", - "\n", - "## SPECIAL WARNING!!!!\n", - "\n", - "**To launch the offline engine in your python scripts,** `__main__` **condition is necessary, since we use** `spawn` **mode to create subprocesses. Please refer to this simple example**:\n", - "\n", - "https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/launch_engine.py" + "\n" ] }, { diff --git a/docs/backend/patch.py b/docs/backend/patch.py index d16422d08..83f52fd3f 100644 --- a/docs/backend/patch.py +++ b/docs/backend/patch.py @@ -1,6 +1,10 @@ import os import weakref +import nest_asyncio + +nest_asyncio.apply() + from sglang.utils import execute_shell_command, reserve_port DEFAULT_MAX_RUNNING_REQUESTS = 200 diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index d70930377..fe5520356 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -26,13 +26,6 @@ from fastapi import HTTPException, Request, UploadFile from fastapi.responses import ORJSONResponse, StreamingResponse from pydantic import ValidationError -try: - from outlines.fsm.json_schema import convert_json_schema_to_str -except ImportError: - # Before outlines 0.0.47, convert_json_schema_to_str is under - # outlines.integrations.utils - from outlines.integrations.utils import convert_json_schema_to_str - from sglang.srt.code_completion_parser import ( generate_completion_prompt_from_request, is_completion_template_defined, @@ -79,7 +72,7 @@ from sglang.srt.openai_api.protocol import ( UsageInfo, ) from sglang.srt.reasoning_parser import ReasoningParser -from sglang.utils import get_exception_traceback +from sglang.utils import convert_json_schema_to_str, get_exception_traceback logger = logging.getLogger(__name__) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 4a751aa88..bd5acd43d 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -22,6 +22,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union import numpy as np import requests from IPython.display import HTML, display +from pydantic import BaseModel from tqdm import tqdm from sglang.srt.utils import kill_process_tree @@ -29,6 +30,36 @@ from sglang.srt.utils import kill_process_tree logger = logging.getLogger(__name__) +def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: + """Convert a JSON schema to a string. + Parameters + ---------- + json_schema + The JSON schema. + Returns + ------- + str + The JSON schema converted to a string. + Raises + ------ + ValueError + If the schema is not a dictionary, a string or a Pydantic class. + """ + if isinstance(json_schema, dict): + schema_str = json.dumps(json_schema) + elif isinstance(json_schema, str): + schema_str = json_schema + elif issubclass(json_schema, BaseModel): + schema_str = json.dumps(json_schema.model_json_schema()) + else: + raise ValueError( + f"Cannot parse schema {json_schema}. The schema must be either " + + "a Pydantic class, a dictionary or a string that contains the JSON " + + "schema specification" + ) + return schema_str + + def get_exception_traceback(): etype, value, tb = sys.exc_info() err_str = "".join(traceback.format_exception(etype, value, tb))