This commit is contained in:
Lianmin Zheng
2024-05-11 20:55:00 -07:00
committed by GitHub
parent 09deb20dee
commit 7023f413c6
8 changed files with 248 additions and 280 deletions

View File

@@ -10,9 +10,12 @@ from io import BytesIO
from typing import List, Optional
import numpy as np
import pydantic
import requests
import torch
from packaging import version as pkg_version
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
# Module-level switch, off by default; presumably gates timing output elsewhere in this file — TODO confirm.
show_time_cost = False
# Module-level dict, empty at import time; presumably accumulates timing records per label — TODO confirm.
time_infos = {}
@@ -120,7 +123,7 @@ def check_port(port):
return False
def handle_port_init(
def allocate_init_ports(
port: Optional[int] = None,
additional_ports: Optional[List[int]] = None,
tp_size: int = 1,
@@ -159,8 +162,6 @@ def get_exception_traceback():
def get_int_token_logit_bias(tokenizer, vocab_size):
from transformers import LlamaTokenizer, LlamaTokenizerFast
# a bug when model's vocab size > tokenizer.vocab_size
vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -281,3 +282,32 @@ def assert_pkg_version(pkg: str, min_version: str):
)
except PackageNotFoundError:
raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
# Name of the HTTP request header that APIKeyValidatorMiddleware reads the client's API key from.
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """HTTP middleware that rejects requests lacking the expected API key.

    Each request must carry the configured key in the header named by
    ``API_KEY_HEADER_NAME``. Requests that do not are answered with a 403 JSON
    response and are never forwarded to the wrapped application.
    """

    def __init__(self, app, api_key: str):
        """Wrap ``app`` and remember the key that incoming requests must present.

        Args:
            app: The application to wrap; handed to the base middleware.
            api_key: The secret value clients must send in the API-key header.
        """
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request, call_next):
        """Validate the API-key header, then forward the request or reject it.

        Args:
            request: Incoming request; only its ``headers`` mapping is read here.
            call_next: Awaitable callable that forwards the request downstream.

        Returns:
            The downstream response when the key matches, otherwise a
            ``JSONResponse`` with status 403 and ``{"detail": "Invalid API Key"}``.
        """
        # Function-scope imports keep this class self-contained: neither name
        # is among the module-level imports visible next to this definition.
        import hmac

        from starlette.responses import JSONResponse

        provided_key = request.headers.get(API_KEY_HEADER_NAME)
        expected_key = self.api_key

        # hmac.compare_digest avoids content-based short-circuiting, so the
        # response time does not hint at how much of a guessed key was right.
        # It only accepts ASCII-only str, hence the explicit encode to bytes.
        # A missing/empty header or a non-string configured key is rejected,
        # exactly as the previous `!=` comparison did.
        is_authorized = (
            bool(provided_key)
            and isinstance(expected_key, str)
            and hmac.compare_digest(
                provided_key.encode("utf-8"), expected_key.encode("utf-8")
            )
        )
        if not is_authorized:
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )

        return await call_next(request)
# FIXME: Remove this once we drop support for pydantic 1.x
# True when the first dot-separated component of the installed pydantic
# version string (pydantic.VERSION) is 1, i.e. a pydantic 1.x release.
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
def jsonify_pydantic_model(obj: BaseModel):
    """Serialize a pydantic model to JSON using the API of the installed major version.

    Args:
        obj: The model instance to serialize.

    Returns:
        The result of ``obj.model_dump_json()`` unless ``IS_PYDANTIC_1`` is
        set, in which case the result of ``obj.json(ensure_ascii=False)``.
    """
    if not IS_PYDANTIC_1:
        # Newer API: used whenever the installed pydantic is not 1.x.
        return obj.model_dump_json()

    # Legacy 1.x API; ensure_ascii=False is passed so non-ASCII text is not escaped.
    return obj.json(ensure_ascii=False)