Add orjson dependency and use ORJSONResponse in place of JSONResponse (#1688)

This commit is contained in:
Michael Feil
2024-10-16 18:14:30 -07:00
committed by GitHub
parent ecb8bad276
commit b0facb3316
4 changed files with 11 additions and 11 deletions

View File

@@ -21,7 +21,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
"packaging", "pillow", "psutil", "pydantic", "python-multipart", "orjson", "packaging", "pillow", "psutil", "pydantic", "python-multipart",
"torchao", "uvicorn", "uvloop", "zmq", "torchao", "uvicorn", "uvloop", "zmq",
"outlines>=0.0.44", "modelscope"] "outlines>=0.0.44", "modelscope"]
# xpu is not enabled in public vllm and torch whl, # xpu is not enabled in public vllm and torch whl,

View File

@@ -25,7 +25,7 @@ from http import HTTPStatus
from typing import Dict, List from typing import Dict, List
from fastapi import HTTPException, Request, UploadFile from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import ValidationError from pydantic import ValidationError
try: try:
@@ -101,7 +101,7 @@ def create_error_response(
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
): ):
error = ErrorResponse(message=message, type=err_type, code=status_code.value) error = ErrorResponse(message=message, type=err_type, code=status_code.value)
return JSONResponse(content=error.model_dump(), status_code=error.code) return ORJSONResponse(content=error.model_dump(), status_code=error.code)
def create_streaming_error_response( def create_streaming_error_response(

View File

@@ -40,7 +40,7 @@ import uvicorn
import uvloop import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -176,12 +176,12 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
success, message = await tokenizer_manager.update_weights(obj, request) success, message = await tokenizer_manager.update_weights(obj, request)
content = {"success": success, "message": message} content = {"success": success, "message": message}
if success: if success:
return JSONResponse( return ORJSONResponse(
content, content,
status_code=HTTPStatus.OK, status_code=HTTPStatus.OK,
) )
else: else:
return JSONResponse( return ORJSONResponse(
content, content,
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
) )
@@ -211,7 +211,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__() ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret return ret
except ValueError as e: except ValueError as e:
return JSONResponse( return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
) )
@@ -226,7 +226,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__() ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret return ret
except ValueError as e: except ValueError as e:
return JSONResponse( return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
) )
@@ -241,7 +241,7 @@ async def judge_request(obj: RewardReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__() ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret return ret
except ValueError as e: except ValueError as e:
return JSONResponse( return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
) )

View File

@@ -35,7 +35,7 @@ import psutil
import requests import requests
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from fastapi.responses import JSONResponse from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from torch import nn from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function from torch.profiler import ProfilerActivity, profile, record_function
@@ -566,7 +566,7 @@ def add_api_key_middleware(app, api_key: str):
if request.url.path.startswith("/health"): if request.url.path.startswith("/health"):
return await call_next(request) return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + api_key: if request.headers.get("Authorization") != "Bearer " + api_key:
return JSONResponse(content={"error": "Unauthorized"}, status_code=401) return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
return await call_next(request) return await call_next(request)