diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index fb8be846b..7059fd95a 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -1,10 +1,11 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py import enum -import json import logging from dataclasses import dataclass, field from typing import List, Optional, Union +import orjson + from sglang.srt.utils import is_hip logger = logging.getLogger(__name__) @@ -66,7 +67,7 @@ class LoadConfig: def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} if isinstance(model_loader_extra_config, str): - self.model_loader_extra_config = json.loads(model_loader_extra_config) + self.model_loader_extra_config = orjson.loads(model_loader_extra_config) self._verify_load_format() if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: diff --git a/python/sglang/srt/entrypoints/context.py b/python/sglang/srt/entrypoints/context.py index 66f58200f..9314083b4 100644 --- a/python/sglang/srt/entrypoints/context.py +++ b/python/sglang/srt/entrypoints/context.py @@ -5,6 +5,8 @@ import logging from abc import ABC, abstractmethod from typing import Union +import orjson + logger = logging.getLogger(__name__) try: @@ -148,7 +150,7 @@ class HarmonyContext(ConversationContext): if isinstance(tool_session, Tool): return await tool_session.get_result(self) tool_name = last_msg.recipient.split(".")[1] - args = json.loads(last_msg.content[0].text) + args = orjson.loads(last_msg.content[0].text) result = await tool_session.call_tool(tool_name, args) result_str = result.content[0].text content = TextContent(text=result_str) diff --git a/python/sglang/srt/entrypoints/harmony_utils.py b/python/sglang/srt/entrypoints/harmony_utils.py index 5ebb653b3..ad6350d16 100644 --- a/python/sglang/srt/entrypoints/harmony_utils.py +++ b/python/sglang/srt/entrypoints/harmony_utils.py @@ 
-7,6 +7,7 @@ import json from collections.abc import Iterable from typing import Literal, Optional, Union +import orjson from openai.types.responses import ( ResponseOutputItem, ResponseOutputMessage, @@ -228,7 +229,7 @@ def parse_output_message(message: Message): if len(message.content) != 1: raise ValueError("Invalid number of contents in browser message") content = message.content[0] - browser_call = json.loads(content.text) + browser_call = orjson.loads(content.text) # TODO: translate to url properly! if recipient == "browser.search": action = ActionSearch( diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index c64e309c4..4da8e880e 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -555,7 +555,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): async def generate_from_file_request(file: UploadFile, request: Request): """Handle a generate request, this is purely to work with input_embeds.""" content = await file.read() - input_embeds = json.loads(content.decode("utf-8")) + input_embeds = orjson.loads(content.decode("utf-8")) obj = GenerateReqInput( input_embeds=input_embeds, diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index 2e027fd48..d42a942f3 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -6,6 +6,7 @@ import uuid from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional, Union +import orjson from fastapi import HTTPException, Request from fastapi.responses import ORJSONResponse, StreamingResponse @@ -197,7 +198,7 @@ class OpenAIServingBase(ABC): ) try: raw_labels = ( - json.loads(raw_request.headers.get(header)) + orjson.loads(raw_request.headers.get(header)) if raw_request and raw_request.headers.get(header) else None ) diff --git 
a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 08a2bf20d..719fa2814 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -7,6 +7,7 @@ import time import uuid from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union +import orjson from fastapi import Request from fastapi.responses import ORJSONResponse, StreamingResponse from jsonschema import Draft202012Validator, SchemaError @@ -285,7 +286,7 @@ class OpenAIServingChat(OpenAIServingBase): if "arguments" in item["function"] and isinstance( item["function"]["arguments"], str ): - item["function"]["arguments"] = json.loads( + item["function"]["arguments"] = orjson.loads( item["function"]["arguments"] ) @@ -860,7 +861,7 @@ class OpenAIServingChat(OpenAIServingBase): finish_reason["matched"] = None try: # For required tool choice, we expect a JSON array of tool calls - tool_call_data = json.loads(text) + tool_call_data = orjson.loads(text) tool_calls = [] for i, tool in enumerate(tool_call_data): # Create a ToolCallItem from the JSON data diff --git a/python/sglang/srt/entrypoints/openai/serving_responses.py b/python/sglang/srt/entrypoints/openai/serving_responses.py index 958aee933..87b1f2b6b 100644 --- a/python/sglang/srt/entrypoints/openai/serving_responses.py +++ b/python/sglang/srt/entrypoints/openai/serving_responses.py @@ -5,7 +5,6 @@ from __future__ import annotations import asyncio import copy -import json import logging import time from contextlib import AsyncExitStack @@ -14,6 +13,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional, import jinja2 import openai.types.responses as openai_responses_types +import orjson from fastapi import Request from fastapi.responses import ORJSONResponse from openai.types.responses import ( @@ -1061,7 +1061,7 @@ class OpenAIServingResponses(OpenAIServingChat): ): function_name 
= previous_item.recipient[len("browser.") :] action = None - parsed_args = json.loads(previous_item.content[0].text) + parsed_args = orjson.loads(previous_item.content[0].text) if function_name == "search": action = openai_responses_types.response_function_web_search.ActionSearch( type="search", diff --git a/python/sglang/srt/function_call/base_format_detector.py b/python/sglang/srt/function_call/base_format_detector.py index efc001d31..02a75c389 100644 --- a/python/sglang/srt/function_call/base_format_detector.py +++ b/python/sglang/srt/function_call/base_format_detector.py @@ -3,6 +3,7 @@ import logging from abc import ABC, abstractmethod from typing import Any, Dict, List +import orjson from partial_json_parser.core.exceptions import MalformedJSON from partial_json_parser.core.options import Allow @@ -96,7 +97,7 @@ class BaseFormatDetector(ABC): Parses the text in one go. Returns success=True if the format matches, otherwise False. Note that leftover_text here represents "content that this parser will not consume further".
""" - action = json.loads(text) + action = orjson.loads(text) return StreamingParseResult(calls=self.parse_base_json(action, tools)) def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int: diff --git a/python/sglang/srt/function_call/utils.py b/python/sglang/srt/function_call/utils.py index 898e13b13..5ad3f6e89 100644 --- a/python/sglang/srt/function_call/utils.py +++ b/python/sglang/srt/function_call/utils.py @@ -3,6 +3,7 @@ from json import JSONDecodeError, JSONDecoder from json.decoder import WHITESPACE from typing import Any, List, Literal, Optional, Tuple, Union +import orjson import partial_json_parser from partial_json_parser.core.options import Allow @@ -51,7 +52,7 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: def _is_complete_json(input_str: str) -> bool: try: - json.loads(input_str) + orjson.loads(input_str) return True except JSONDecodeError: return False diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index be9e5699a..c034c37b9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -34,6 +34,7 @@ from http import HTTPStatus from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union import fastapi +import orjson import torch import uvloop import zmq @@ -157,7 +158,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): self.log_requests = server_args.log_requests self.log_requests_level = server_args.log_requests_level self.preferred_sampling_params = ( - json.loads(server_args.preferred_sampling_params) + orjson.loads(server_args.preferred_sampling_params) if server_args.preferred_sampling_params else None ) diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py index 67514819c..80820c361 100644 --- a/python/sglang/srt/sampling/custom_logit_processor.py +++ 
b/python/sglang/srt/sampling/custom_logit_processor.py @@ -4,6 +4,7 @@ from functools import lru_cache from typing import Any, Dict, List, Optional import dill +import orjson import torch @@ -12,7 +13,7 @@ def _cache_from_str(json_str: str): """Deserialize a json string to a Callable object. This function is cached to avoid redundant deserialization. """ - data = json.loads(json_str) + data = orjson.loads(json_str) return dill.loads(bytes.fromhex(data["callable"])) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 29d5cc03e..a73758d52 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -22,6 +22,8 @@ import random import tempfile from typing import Dict, List, Literal, Optional, Union +import orjson + from sglang.srt.connector import ConnectorType from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.lora.lora_registry import LoRARef @@ -3041,7 +3043,7 @@ class ServerArgs: self.model_path, trust_remote_code=self.trust_remote_code, revision=self.revision, - model_override_args=json.loads(self.json_model_override_args), + model_override_args=orjson.loads(self.json_model_override_args), **kwargs, ) return hf_config diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 67fc8f608..b65c311f9 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -12,7 +12,6 @@ # limitations under the License. # ============================================================================== """Common utilities.""" - from __future__ import annotations import argparse @@ -70,6 +69,7 @@ from typing import ( ) import numpy as np +import orjson import psutil import pybase64 import requests @@ -1112,7 +1112,7 @@ def configure_logger(server_args, prefix: str = ""): f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!" 
) with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file: - custom_config = json.loads(file.read()) + custom_config = orjson.loads(file.read()) logging.config.dictConfig(custom_config) return format = f"[%(asctime)s{prefix}] %(message)s" @@ -2525,9 +2525,9 @@ def log_info_on_rank0(logger, msg): def load_json_config(data: str): try: - return json.loads(data) + return orjson.loads(data) except JSONDecodeError: - return json.loads(Path(data).read_text()) + return orjson.loads(Path(data).read_text()) def dispose_tensor(x: torch.Tensor): @@ -3236,7 +3236,7 @@ def numa_bind_to_node(node: int): def json_list_type(value): try: - return json.loads(value) + return orjson.loads(value) except json.JSONDecodeError: raise argparse.ArgumentTypeError( f"Invalid JSON list: {value}. Please provide a valid JSON list."