[perf] Replace json -> orjson in hot path (#11221)

Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
This commit is contained in:
Vincent Zhong
2025-10-12 08:30:58 -04:00
committed by GitHub
parent 7b064f04f8
commit a220536f40
13 changed files with 32 additions and 20 deletions

View File

@@ -1,10 +1,11 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum import enum
import json
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional, Union from typing import List, Optional, Union
import orjson
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -66,7 +67,7 @@ class LoadConfig:
def __post_init__(self): def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {} model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str): if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(model_loader_extra_config) self.model_loader_extra_config = orjson.loads(model_loader_extra_config)
self._verify_load_format() self._verify_load_format()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:

View File

@@ -5,6 +5,8 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Union from typing import Union
import orjson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:
@@ -148,7 +150,7 @@ class HarmonyContext(ConversationContext):
if isinstance(tool_session, Tool): if isinstance(tool_session, Tool):
return await tool_session.get_result(self) return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1] tool_name = last_msg.recipient.split(".")[1]
args = json.loads(last_msg.content[0].text) args = orjson.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args) result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text result_str = result.content[0].text
content = TextContent(text=result_str) content = TextContent(text=result_str)

View File

@@ -7,6 +7,7 @@ import json
from collections.abc import Iterable from collections.abc import Iterable
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
import orjson
from openai.types.responses import ( from openai.types.responses import (
ResponseOutputItem, ResponseOutputItem,
ResponseOutputMessage, ResponseOutputMessage,
@@ -228,7 +229,7 @@ def parse_output_message(message: Message):
if len(message.content) != 1: if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message") raise ValueError("Invalid number of contents in browser message")
content = message.content[0] content = message.content[0]
browser_call = json.loads(content.text) browser_call = orjson.loads(content.text)
# TODO: translate to url properly! # TODO: translate to url properly!
if recipient == "browser.search": if recipient == "browser.search":
action = ActionSearch( action = ActionSearch(

View File

@@ -555,7 +555,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
async def generate_from_file_request(file: UploadFile, request: Request): async def generate_from_file_request(file: UploadFile, request: Request):
"""Handle a generate request, this is purely to work with input_embeds.""" """Handle a generate request, this is purely to work with input_embeds."""
content = await file.read() content = await file.read()
input_embeds = json.loads(content.decode("utf-8")) input_embeds = orjson.loads(content.decode("utf-8"))
obj = GenerateReqInput( obj = GenerateReqInput(
input_embeds=input_embeds, input_embeds=input_embeds,

View File

@@ -6,6 +6,7 @@ import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import orjson
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from fastapi.responses import ORJSONResponse, StreamingResponse from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -197,7 +198,7 @@ class OpenAIServingBase(ABC):
) )
try: try:
raw_labels = ( raw_labels = (
json.loads(raw_request.headers.get(header)) orjson.loads(raw_request.headers.get(header))
if raw_request and raw_request.headers.get(header) if raw_request and raw_request.headers.get(header)
else None else None
) )

View File

@@ -7,6 +7,7 @@ import time
import uuid import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
import orjson
from fastapi import Request from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse from fastapi.responses import ORJSONResponse, StreamingResponse
from jsonschema import Draft202012Validator, SchemaError from jsonschema import Draft202012Validator, SchemaError
@@ -285,7 +286,7 @@ class OpenAIServingChat(OpenAIServingBase):
if "arguments" in item["function"] and isinstance( if "arguments" in item["function"] and isinstance(
item["function"]["arguments"], str item["function"]["arguments"], str
): ):
item["function"]["arguments"] = json.loads( item["function"]["arguments"] = orjson.loads(
item["function"]["arguments"] item["function"]["arguments"]
) )
@@ -860,7 +861,7 @@ class OpenAIServingChat(OpenAIServingBase):
finish_reason["matched"] = None finish_reason["matched"] = None
try: try:
# For required tool choice, we expect a JSON array of tool calls # For required tool choice, we expect a JSON array of tool calls
tool_call_data = json.loads(text) tool_call_data = orjson.loads(text)
tool_calls = [] tool_calls = []
for i, tool in enumerate(tool_call_data): for i, tool in enumerate(tool_call_data):
# Create a ToolCallItem from the JSON data # Create a ToolCallItem from the JSON data

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio import asyncio
import copy import copy
import json
import logging import logging
import time import time
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
@@ -14,6 +13,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional,
import jinja2 import jinja2
import openai.types.responses as openai_responses_types import openai.types.responses as openai_responses_types
import orjson
from fastapi import Request from fastapi import Request
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from openai.types.responses import ( from openai.types.responses import (
@@ -1061,7 +1061,7 @@ class OpenAIServingResponses(OpenAIServingChat):
): ):
function_name = previous_item.recipient[len("browser.") :] function_name = previous_item.recipient[len("browser.") :]
action = None action = None
parsed_args = json.loads(previous_item.content[0].text) parsed_args = orjson.loads(previous_item.content[0].text)
if function_name == "search": if function_name == "search":
action = openai_responses_types.response_function_web_search.ActionSearch( action = openai_responses_types.response_function_web_search.ActionSearch(
type="search", type="search",

View File

@@ -3,6 +3,7 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List
import orjson
from partial_json_parser.core.exceptions import MalformedJSON from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
@@ -96,7 +97,7 @@ class BaseFormatDetector(ABC):
Parses the text in one go. Returns success=True if the format matches, otherwise False. Parses the text in one go. Returns success=True if the format matches, otherwise False.
Note that leftover_text here represents "content that this parser will not consume further". Note that leftover_text here represents "content that this parser will not consume further".
""" """
action = json.loads(text) action = orjson.loads(text)
return StreamingParseResult(calls=self.parse_base_json(action, tools)) return StreamingParseResult(calls=self.parse_base_json(action, tools))
def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int: def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:

View File

@@ -3,6 +3,7 @@ from json import JSONDecodeError, JSONDecoder
from json.decoder import WHITESPACE from json.decoder import WHITESPACE
from typing import Any, List, Literal, Optional, Tuple, Union from typing import Any, List, Literal, Optional, Tuple, Union
import orjson
import partial_json_parser import partial_json_parser
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
@@ -51,7 +52,7 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
def _is_complete_json(input_str: str) -> bool: def _is_complete_json(input_str: str) -> bool:
try: try:
json.loads(input_str) orjson.loads(input_str)
return True return True
except JSONDecodeError: except JSONDecodeError:
return False return False

View File

@@ -34,6 +34,7 @@ from http import HTTPStatus
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
import fastapi import fastapi
import orjson
import torch import torch
import uvloop import uvloop
import zmq import zmq
@@ -157,7 +158,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.log_requests = server_args.log_requests self.log_requests = server_args.log_requests
self.log_requests_level = server_args.log_requests_level self.log_requests_level = server_args.log_requests_level
self.preferred_sampling_params = ( self.preferred_sampling_params = (
json.loads(server_args.preferred_sampling_params) orjson.loads(server_args.preferred_sampling_params)
if server_args.preferred_sampling_params if server_args.preferred_sampling_params
else None else None
) )

View File

@@ -4,6 +4,7 @@ from functools import lru_cache
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import dill import dill
import orjson
import torch import torch
@@ -12,7 +13,7 @@ def _cache_from_str(json_str: str):
"""Deserialize a json string to a Callable object. """Deserialize a json string to a Callable object.
This function is cached to avoid redundant deserialization. This function is cached to avoid redundant deserialization.
""" """
data = json.loads(json_str) data = orjson.loads(json_str)
return dill.loads(bytes.fromhex(data["callable"])) return dill.loads(bytes.fromhex(data["callable"]))

View File

@@ -22,6 +22,8 @@ import random
import tempfile import tempfile
from typing import Dict, List, Literal, Optional, Union from typing import Dict, List, Literal, Optional, Union
import orjson
from sglang.srt.connector import ConnectorType from sglang.srt.connector import ConnectorType
from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.lora_registry import LoRARef
@@ -3041,7 +3043,7 @@ class ServerArgs:
self.model_path, self.model_path,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
revision=self.revision, revision=self.revision,
model_override_args=json.loads(self.json_model_override_args), model_override_args=orjson.loads(self.json_model_override_args),
**kwargs, **kwargs,
) )
return hf_config return hf_config

View File

@@ -12,7 +12,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Common utilities.""" """Common utilities."""
from __future__ import annotations from __future__ import annotations
import argparse import argparse
@@ -70,6 +69,7 @@ from typing import (
) )
import numpy as np import numpy as np
import orjson
import psutil import psutil
import pybase64 import pybase64
import requests import requests
@@ -1112,7 +1112,7 @@ def configure_logger(server_args, prefix: str = ""):
f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!" f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
) )
with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file: with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
custom_config = json.loads(file.read()) custom_config = orjson.loads(file.read())
logging.config.dictConfig(custom_config) logging.config.dictConfig(custom_config)
return return
format = f"[%(asctime)s{prefix}] %(message)s" format = f"[%(asctime)s{prefix}] %(message)s"
@@ -2525,9 +2525,9 @@ def log_info_on_rank0(logger, msg):
def load_json_config(data: str): def load_json_config(data: str):
try: try:
return json.loads(data) return orjson.loads(data)
except JSONDecodeError: except JSONDecodeError:
return json.loads(Path(data).read_text()) return orjson.loads(Path(data).read_text())
def dispose_tensor(x: torch.Tensor): def dispose_tensor(x: torch.Tensor):
@@ -3236,7 +3236,7 @@ def numa_bind_to_node(node: int):
def json_list_type(value): def json_list_type(value):
try: try:
return json.loads(value) return orjson.loads(value)
except json.JSONDecodeError: except json.JSONDecodeError:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
f"Invalid JSON list: {value}. Please provide a valid JSON list." f"Invalid JSON list: {value}. Please provide a valid JSON list."