[ perf ] Replace json -> orjson in hot path (#11221)
Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
import enum
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import orjson
|
||||
|
||||
from sglang.srt.utils import is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -66,7 +67,7 @@ class LoadConfig:
|
||||
def __post_init__(self):
|
||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||
if isinstance(model_loader_extra_config, str):
|
||||
self.model_loader_extra_config = json.loads(model_loader_extra_config)
|
||||
self.model_loader_extra_config = orjson.loads(model_loader_extra_config)
|
||||
self._verify_load_format()
|
||||
|
||||
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
|
||||
|
||||
@@ -5,6 +5,8 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
|
||||
import orjson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
@@ -148,7 +150,7 @@ class HarmonyContext(ConversationContext):
|
||||
if isinstance(tool_session, Tool):
|
||||
return await tool_session.get_result(self)
|
||||
tool_name = last_msg.recipient.split(".")[1]
|
||||
args = json.loads(last_msg.content[0].text)
|
||||
args = orjson.loads(last_msg.content[0].text)
|
||||
result = await tool_session.call_tool(tool_name, args)
|
||||
result_str = result.content[0].text
|
||||
content = TextContent(text=result_str)
|
||||
|
||||
@@ -7,6 +7,7 @@ import json
|
||||
from collections.abc import Iterable
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import orjson
|
||||
from openai.types.responses import (
|
||||
ResponseOutputItem,
|
||||
ResponseOutputMessage,
|
||||
@@ -228,7 +229,7 @@ def parse_output_message(message: Message):
|
||||
if len(message.content) != 1:
|
||||
raise ValueError("Invalid number of contents in browser message")
|
||||
content = message.content[0]
|
||||
browser_call = json.loads(content.text)
|
||||
browser_call = orjson.loads(content.text)
|
||||
# TODO: translate to url properly!
|
||||
if recipient == "browser.search":
|
||||
action = ActionSearch(
|
||||
|
||||
@@ -555,7 +555,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
||||
async def generate_from_file_request(file: UploadFile, request: Request):
|
||||
"""Handle a generate request, this is purely to work with input_embeds."""
|
||||
content = await file.read()
|
||||
input_embeds = json.loads(content.decode("utf-8"))
|
||||
input_embeds = orjson.loads(content.decode("utf-8"))
|
||||
|
||||
obj = GenerateReqInput(
|
||||
input_embeds=input_embeds,
|
||||
|
||||
@@ -6,6 +6,7 @@ import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import orjson
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
|
||||
@@ -197,7 +198,7 @@ class OpenAIServingBase(ABC):
|
||||
)
|
||||
try:
|
||||
raw_labels = (
|
||||
json.loads(raw_request.headers.get(header))
|
||||
orjson.loads(raw_request.headers.get(header))
|
||||
if raw_request and raw_request.headers.get(header)
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
import orjson
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
from jsonschema import Draft202012Validator, SchemaError
|
||||
@@ -285,7 +286,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
if "arguments" in item["function"] and isinstance(
|
||||
item["function"]["arguments"], str
|
||||
):
|
||||
item["function"]["arguments"] = json.loads(
|
||||
item["function"]["arguments"] = orjson.loads(
|
||||
item["function"]["arguments"]
|
||||
)
|
||||
|
||||
@@ -860,7 +861,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
finish_reason["matched"] = None
|
||||
try:
|
||||
# For required tool choice, we expect a JSON array of tool calls
|
||||
tool_call_data = json.loads(text)
|
||||
tool_call_data = orjson.loads(text)
|
||||
tool_calls = []
|
||||
for i, tool in enumerate(tool_call_data):
|
||||
# Create a ToolCallItem from the JSON data
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AsyncExitStack
|
||||
@@ -14,6 +13,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional,
|
||||
|
||||
import jinja2
|
||||
import openai.types.responses as openai_responses_types
|
||||
import orjson
|
||||
from fastapi import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from openai.types.responses import (
|
||||
@@ -1061,7 +1061,7 @@ class OpenAIServingResponses(OpenAIServingChat):
|
||||
):
|
||||
function_name = previous_item.recipient[len("browser.") :]
|
||||
action = None
|
||||
parsed_args = json.loads(previous_item.content[0].text)
|
||||
parsed_args = orjson.loads(previous_item.content[0].text)
|
||||
if function_name == "search":
|
||||
action = openai_responses_types.response_function_web_search.ActionSearch(
|
||||
type="search",
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import orjson
|
||||
from partial_json_parser.core.exceptions import MalformedJSON
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
@@ -96,7 +97,7 @@ class BaseFormatDetector(ABC):
|
||||
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
||||
Note that leftover_text here represents "content that this parser will not consume further".
|
||||
"""
|
||||
action = json.loads(text)
|
||||
action = orjson.loads(text)
|
||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||
|
||||
def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
|
||||
|
||||
@@ -3,6 +3,7 @@ from json import JSONDecodeError, JSONDecoder
|
||||
from json.decoder import WHITESPACE
|
||||
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import orjson
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
@@ -51,7 +52,7 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
||||
|
||||
def _is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
orjson.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
@@ -34,6 +34,7 @@ from http import HTTPStatus
|
||||
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import fastapi
|
||||
import orjson
|
||||
import torch
|
||||
import uvloop
|
||||
import zmq
|
||||
@@ -157,7 +158,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
self.log_requests = server_args.log_requests
|
||||
self.log_requests_level = server_args.log_requests_level
|
||||
self.preferred_sampling_params = (
|
||||
json.loads(server_args.preferred_sampling_params)
|
||||
orjson.loads(server_args.preferred_sampling_params)
|
||||
if server_args.preferred_sampling_params
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from functools import lru_cache
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import dill
|
||||
import orjson
|
||||
import torch
|
||||
|
||||
|
||||
@@ -12,7 +13,7 @@ def _cache_from_str(json_str: str):
|
||||
"""Deserialize a json string to a Callable object.
|
||||
This function is cached to avoid redundant deserialization.
|
||||
"""
|
||||
data = json.loads(json_str)
|
||||
data = orjson.loads(json_str)
|
||||
return dill.loads(bytes.fromhex(data["callable"]))
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ import random
|
||||
import tempfile
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import orjson
|
||||
|
||||
from sglang.srt.connector import ConnectorType
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.lora.lora_registry import LoRARef
|
||||
@@ -3041,7 +3043,7 @@ class ServerArgs:
|
||||
self.model_path,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
revision=self.revision,
|
||||
model_override_args=json.loads(self.json_model_override_args),
|
||||
model_override_args=orjson.loads(self.json_model_override_args),
|
||||
**kwargs,
|
||||
)
|
||||
return hf_config
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Common utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
@@ -70,6 +69,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
import psutil
|
||||
import pybase64
|
||||
import requests
|
||||
@@ -1112,7 +1112,7 @@ def configure_logger(server_args, prefix: str = ""):
|
||||
f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
|
||||
)
|
||||
with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
|
||||
custom_config = json.loads(file.read())
|
||||
custom_config = orjson.loads(file.read())
|
||||
logging.config.dictConfig(custom_config)
|
||||
return
|
||||
format = f"[%(asctime)s{prefix}] %(message)s"
|
||||
@@ -2525,9 +2525,9 @@ def log_info_on_rank0(logger, msg):
|
||||
|
||||
def load_json_config(data: str):
|
||||
try:
|
||||
return json.loads(data)
|
||||
return orjson.loads(data)
|
||||
except JSONDecodeError:
|
||||
return json.loads(Path(data).read_text())
|
||||
return orjson.loads(Path(data).read_text())
|
||||
|
||||
|
||||
def dispose_tensor(x: torch.Tensor):
|
||||
@@ -3236,7 +3236,7 @@ def numa_bind_to_node(node: int):
|
||||
|
||||
def json_list_type(value):
|
||||
try:
|
||||
return json.loads(value)
|
||||
return orjson.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Invalid JSON list: {value}. Please provide a valid JSON list."
|
||||
|
||||
Reference in New Issue
Block a user