ORJson. Faster Json serialization (#1694)

This commit is contained in:
Michael Feil
2024-10-17 08:03:08 -07:00
committed by GitHub
parent b170930534
commit e5db40dcbc

View File

@@ -28,7 +28,9 @@ import os
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional, Union
+
+import orjson

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -192,14 +194,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:

-        async def stream_results():
+        async def stream_results() -> AsyncIterator[bytes]:
             try:
                 async for out in tokenizer_manager.generate_request(obj, request):
-                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
             except ValueError as e:
                 out = {"error": {"message": str(e)}}
-                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
-            yield "data: [DONE]\n\n"
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"

         return StreamingResponse(
             stream_results(),
@@ -260,13 +266,13 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


-@app.post("/v1/embeddings")
+@app.post("/v1/embeddings", response_class=ORJSONResponse)
 async def openai_v1_embeddings(raw_request: Request):
     response = await v1_embeddings(tokenizer_manager, raw_request)
     return response


-@app.get("/v1/models")
+@app.get("/v1/models", response_class=ORJSONResponse)
 def available_models():
     """Show available models."""
     served_model_names = [tokenizer_manager.served_model_name]