[ perf ] Replace json-> orjson in hot path (#11221)
Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||||
import enum
|
import enum
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
|
||||||
from sglang.srt.utils import is_hip
|
from sglang.srt.utils import is_hip
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -66,7 +67,7 @@ class LoadConfig:
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
model_loader_extra_config = self.model_loader_extra_config or {}
|
model_loader_extra_config = self.model_loader_extra_config or {}
|
||||||
if isinstance(model_loader_extra_config, str):
|
if isinstance(model_loader_extra_config, str):
|
||||||
self.model_loader_extra_config = json.loads(model_loader_extra_config)
|
self.model_loader_extra_config = orjson.loads(model_loader_extra_config)
|
||||||
self._verify_load_format()
|
self._verify_load_format()
|
||||||
|
|
||||||
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
|
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import logging
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -148,7 +150,7 @@ class HarmonyContext(ConversationContext):
|
|||||||
if isinstance(tool_session, Tool):
|
if isinstance(tool_session, Tool):
|
||||||
return await tool_session.get_result(self)
|
return await tool_session.get_result(self)
|
||||||
tool_name = last_msg.recipient.split(".")[1]
|
tool_name = last_msg.recipient.split(".")[1]
|
||||||
args = json.loads(last_msg.content[0].text)
|
args = orjson.loads(last_msg.content[0].text)
|
||||||
result = await tool_session.call_tool(tool_name, args)
|
result = await tool_session.call_tool(tool_name, args)
|
||||||
result_str = result.content[0].text
|
result_str = result.content[0].text
|
||||||
content = TextContent(text=result_str)
|
content = TextContent(text=result_str)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import json
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
|
import orjson
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
ResponseOutputItem,
|
ResponseOutputItem,
|
||||||
ResponseOutputMessage,
|
ResponseOutputMessage,
|
||||||
@@ -228,7 +229,7 @@ def parse_output_message(message: Message):
|
|||||||
if len(message.content) != 1:
|
if len(message.content) != 1:
|
||||||
raise ValueError("Invalid number of contents in browser message")
|
raise ValueError("Invalid number of contents in browser message")
|
||||||
content = message.content[0]
|
content = message.content[0]
|
||||||
browser_call = json.loads(content.text)
|
browser_call = orjson.loads(content.text)
|
||||||
# TODO: translate to url properly!
|
# TODO: translate to url properly!
|
||||||
if recipient == "browser.search":
|
if recipient == "browser.search":
|
||||||
action = ActionSearch(
|
action = ActionSearch(
|
||||||
|
|||||||
@@ -555,7 +555,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
async def generate_from_file_request(file: UploadFile, request: Request):
|
async def generate_from_file_request(file: UploadFile, request: Request):
|
||||||
"""Handle a generate request, this is purely to work with input_embeds."""
|
"""Handle a generate request, this is purely to work with input_embeds."""
|
||||||
content = await file.read()
|
content = await file.read()
|
||||||
input_embeds = json.loads(content.decode("utf-8"))
|
input_embeds = orjson.loads(content.decode("utf-8"))
|
||||||
|
|
||||||
obj = GenerateReqInput(
|
obj = GenerateReqInput(
|
||||||
input_embeds=input_embeds,
|
input_embeds=input_embeds,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import uuid
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||||
|
|
||||||
|
import orjson
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
|
|
||||||
@@ -197,7 +198,7 @@ class OpenAIServingBase(ABC):
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
raw_labels = (
|
raw_labels = (
|
||||||
json.loads(raw_request.headers.get(header))
|
orjson.loads(raw_request.headers.get(header))
|
||||||
if raw_request and raw_request.headers.get(header)
|
if raw_request and raw_request.headers.get(header)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import orjson
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||||
from jsonschema import Draft202012Validator, SchemaError
|
from jsonschema import Draft202012Validator, SchemaError
|
||||||
@@ -285,7 +286,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
if "arguments" in item["function"] and isinstance(
|
if "arguments" in item["function"] and isinstance(
|
||||||
item["function"]["arguments"], str
|
item["function"]["arguments"], str
|
||||||
):
|
):
|
||||||
item["function"]["arguments"] = json.loads(
|
item["function"]["arguments"] = orjson.loads(
|
||||||
item["function"]["arguments"]
|
item["function"]["arguments"]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -860,7 +861,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
|||||||
finish_reason["matched"] = None
|
finish_reason["matched"] = None
|
||||||
try:
|
try:
|
||||||
# For required tool choice, we expect a JSON array of tool calls
|
# For required tool choice, we expect a JSON array of tool calls
|
||||||
tool_call_data = json.loads(text)
|
tool_call_data = orjson.loads(text)
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for i, tool in enumerate(tool_call_data):
|
for i, tool in enumerate(tool_call_data):
|
||||||
# Create a ToolCallItem from the JSON data
|
# Create a ToolCallItem from the JSON data
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
@@ -14,6 +13,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional,
|
|||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
import openai.types.responses as openai_responses_types
|
import openai.types.responses as openai_responses_types
|
||||||
|
import orjson
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
@@ -1061,7 +1061,7 @@ class OpenAIServingResponses(OpenAIServingChat):
|
|||||||
):
|
):
|
||||||
function_name = previous_item.recipient[len("browser.") :]
|
function_name = previous_item.recipient[len("browser.") :]
|
||||||
action = None
|
action = None
|
||||||
parsed_args = json.loads(previous_item.content[0].text)
|
parsed_args = orjson.loads(previous_item.content[0].text)
|
||||||
if function_name == "search":
|
if function_name == "search":
|
||||||
action = openai_responses_types.response_function_web_search.ActionSearch(
|
action = openai_responses_types.response_function_web_search.ActionSearch(
|
||||||
type="search",
|
type="search",
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import logging
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import orjson
|
||||||
from partial_json_parser.core.exceptions import MalformedJSON
|
from partial_json_parser.core.exceptions import MalformedJSON
|
||||||
from partial_json_parser.core.options import Allow
|
from partial_json_parser.core.options import Allow
|
||||||
|
|
||||||
@@ -96,7 +97,7 @@ class BaseFormatDetector(ABC):
|
|||||||
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
||||||
Note that leftover_text here represents "content that this parser will not consume further".
|
Note that leftover_text here represents "content that this parser will not consume further".
|
||||||
"""
|
"""
|
||||||
action = json.loads(text)
|
action = orjson.loads(text)
|
||||||
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
return StreamingParseResult(calls=self.parse_base_json(action, tools))
|
||||||
|
|
||||||
def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
|
def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from json import JSONDecodeError, JSONDecoder
|
|||||||
from json.decoder import WHITESPACE
|
from json.decoder import WHITESPACE
|
||||||
from typing import Any, List, Literal, Optional, Tuple, Union
|
from typing import Any, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import orjson
|
||||||
import partial_json_parser
|
import partial_json_parser
|
||||||
from partial_json_parser.core.options import Allow
|
from partial_json_parser.core.options import Allow
|
||||||
|
|
||||||
@@ -51,7 +52,7 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
|||||||
|
|
||||||
def _is_complete_json(input_str: str) -> bool:
|
def _is_complete_json(input_str: str) -> bool:
|
||||||
try:
|
try:
|
||||||
json.loads(input_str)
|
orjson.loads(input_str)
|
||||||
return True
|
return True
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from http import HTTPStatus
|
|||||||
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
|
import orjson
|
||||||
import torch
|
import torch
|
||||||
import uvloop
|
import uvloop
|
||||||
import zmq
|
import zmq
|
||||||
@@ -157,7 +158,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
self.log_requests = server_args.log_requests
|
self.log_requests = server_args.log_requests
|
||||||
self.log_requests_level = server_args.log_requests_level
|
self.log_requests_level = server_args.log_requests_level
|
||||||
self.preferred_sampling_params = (
|
self.preferred_sampling_params = (
|
||||||
json.loads(server_args.preferred_sampling_params)
|
orjson.loads(server_args.preferred_sampling_params)
|
||||||
if server_args.preferred_sampling_params
|
if server_args.preferred_sampling_params
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from functools import lru_cache
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import dill
|
import dill
|
||||||
|
import orjson
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -12,7 +13,7 @@ def _cache_from_str(json_str: str):
|
|||||||
"""Deserialize a json string to a Callable object.
|
"""Deserialize a json string to a Callable object.
|
||||||
This function is cached to avoid redundant deserialization.
|
This function is cached to avoid redundant deserialization.
|
||||||
"""
|
"""
|
||||||
data = json.loads(json_str)
|
data = orjson.loads(json_str)
|
||||||
return dill.loads(bytes.fromhex(data["callable"]))
|
return dill.loads(bytes.fromhex(data["callable"]))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ import random
|
|||||||
import tempfile
|
import tempfile
|
||||||
from typing import Dict, List, Literal, Optional, Union
|
from typing import Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
|
||||||
from sglang.srt.connector import ConnectorType
|
from sglang.srt.connector import ConnectorType
|
||||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||||
from sglang.srt.lora.lora_registry import LoRARef
|
from sglang.srt.lora.lora_registry import LoRARef
|
||||||
@@ -3041,7 +3043,7 @@ class ServerArgs:
|
|||||||
self.model_path,
|
self.model_path,
|
||||||
trust_remote_code=self.trust_remote_code,
|
trust_remote_code=self.trust_remote_code,
|
||||||
revision=self.revision,
|
revision=self.revision,
|
||||||
model_override_args=json.loads(self.json_model_override_args),
|
model_override_args=orjson.loads(self.json_model_override_args),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return hf_config
|
return hf_config
|
||||||
|
|||||||
@@ -12,7 +12,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Common utilities."""
|
"""Common utilities."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -70,6 +69,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import orjson
|
||||||
import psutil
|
import psutil
|
||||||
import pybase64
|
import pybase64
|
||||||
import requests
|
import requests
|
||||||
@@ -1112,7 +1112,7 @@ def configure_logger(server_args, prefix: str = ""):
|
|||||||
f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
|
f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
|
||||||
)
|
)
|
||||||
with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
|
with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
|
||||||
custom_config = json.loads(file.read())
|
custom_config = orjson.loads(file.read())
|
||||||
logging.config.dictConfig(custom_config)
|
logging.config.dictConfig(custom_config)
|
||||||
return
|
return
|
||||||
format = f"[%(asctime)s{prefix}] %(message)s"
|
format = f"[%(asctime)s{prefix}] %(message)s"
|
||||||
@@ -2525,9 +2525,9 @@ def log_info_on_rank0(logger, msg):
|
|||||||
|
|
||||||
def load_json_config(data: str):
|
def load_json_config(data: str):
|
||||||
try:
|
try:
|
||||||
return json.loads(data)
|
return orjson.loads(data)
|
||||||
except JSONDecodeError:
|
except JSONDecodeError:
|
||||||
return json.loads(Path(data).read_text())
|
return orjson.loads(Path(data).read_text())
|
||||||
|
|
||||||
|
|
||||||
def dispose_tensor(x: torch.Tensor):
|
def dispose_tensor(x: torch.Tensor):
|
||||||
@@ -3236,7 +3236,7 @@ def numa_bind_to_node(node: int):
|
|||||||
|
|
||||||
def json_list_type(value):
|
def json_list_type(value):
|
||||||
try:
|
try:
|
||||||
return json.loads(value)
|
return orjson.loads(value)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise argparse.ArgumentTypeError(
|
raise argparse.ArgumentTypeError(
|
||||||
f"Invalid JSON list: {value}. Please provide a valid JSON list."
|
f"Invalid JSON list: {value}. Please provide a valid JSON list."
|
||||||
|
|||||||
Reference in New Issue
Block a user