Clean up (#422)
This commit is contained in:
@@ -10,9 +10,12 @@ from io import BytesIO
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import pydantic
|
||||
import requests
|
||||
import torch
|
||||
from packaging import version as pkg_version
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
show_time_cost = False
|
||||
time_infos = {}
|
||||
@@ -120,7 +123,7 @@ def check_port(port):
|
||||
return False
|
||||
|
||||
|
||||
def handle_port_init(
|
||||
def allocate_init_ports(
|
||||
port: Optional[int] = None,
|
||||
additional_ports: Optional[List[int]] = None,
|
||||
tp_size: int = 1,
|
||||
@@ -159,8 +162,6 @@ def get_exception_traceback():
|
||||
|
||||
|
||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||
from transformers import LlamaTokenizer, LlamaTokenizerFast
|
||||
|
||||
# a bug when model's vocab size > tokenizer.vocab_size
|
||||
vocab_size = tokenizer.vocab_size
|
||||
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
||||
@@ -281,3 +282,32 @@ def assert_pkg_version(pkg: str, min_version: str):
|
||||
)
|
||||
except PackageNotFoundError:
|
||||
raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
|
||||
|
||||
|
||||
API_KEY_HEADER_NAME = "X-API-Key"
|
||||
|
||||
|
||||
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(self, app, api_key: str):
|
||||
super().__init__(app)
|
||||
self.api_key = api_key
|
||||
|
||||
async def dispatch(self, request, call_next):
|
||||
# extract API key from the request headers
|
||||
api_key_header = request.headers.get(API_KEY_HEADER_NAME)
|
||||
if not api_key_header or api_key_header != self.api_key:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Invalid API Key"},
|
||||
)
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
# FIXME: Remove this once we drop support for pydantic 1.x
|
||||
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
|
||||
|
||||
|
||||
def jsonify_pydantic_model(obj: BaseModel):
|
||||
if IS_PYDANTIC_1:
|
||||
return obj.json(ensure_ascii=False)
|
||||
return obj.model_dump_json()
|
||||
|
||||
Reference in New Issue
Block a user