Make API Key OpenAI-compatible (#917)
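With this change, the server's API key works the same way as the OpenAI API: clients send it as a standard "Authorization: Bearer <key>" header, so the stock OpenAI Python client can talk to a key-protected SGLang server without any custom headers. A minimal usage sketch (the URL, key, and model name below are placeholders, not part of this commit):

    import openai

    client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")
    response = client.chat.completions.create(
        model="default",  # placeholder; use the served model name
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)
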
@@ -15,7 +15,6 @@ class RuntimeEndpoint(BaseBackend):
     def __init__(
         self,
         base_url: str,
-        auth_token: Optional[str] = None,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
     ):
@@ -23,13 +22,11 @@ class RuntimeEndpoint(BaseBackend):
         self.support_concate_and_append = True

         self.base_url = base_url
-        self.auth_token = auth_token
         self.api_key = api_key
         self.verify = verify

         res = http_request(
             self.base_url + "/get_model_info",
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -67,7 +64,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -79,7 +75,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -91,7 +86,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -139,7 +133,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -193,7 +186,6 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             stream=True,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -225,7 +217,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -243,7 +234,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -267,7 +257,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
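The RuntimeEndpoint backend drops the old auth_token parameter and forwards only api_key on every request. A minimal sketch of pointing the frontend at a key-protected runtime (the URL and key are assumed values):

    import sglang as sgl

    backend = sgl.RuntimeEndpoint("http://localhost:30000", api_key="sk-123456")
    sgl.set_default_backend(backend)
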
@@ -67,13 +67,13 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware,
+    add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
+    set_torch_compile_config,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -158,6 +158,16 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    served_model_names = [tokenizer_manager.served_model_name]
+    model_cards = []
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+    return ModelList(data=model_cards)
+
+
 @app.post("/v1/files")
 async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
     return await v1_files_create(
@@ -187,69 +197,11 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)


-@app.get("/v1/models")
-def available_models():
-    """Show available models."""
-    served_model_names = [tokenizer_manager.served_model_name]
-    model_cards = []
-    for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
-    return ModelList(data=model_cards)
-
-
-def _set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
-def set_envs_and_config(server_args: ServerArgs):
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-    # Set ulimit
-    set_ulimit()
-
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    server_args.check_server_args()
-
     """Launch an HTTP server."""
     global tokenizer_manager

@@ -258,16 +210,8 @@ def launch_server(
         format="%(message)s",
     )

-    if not server_args.disable_flashinfer:
-        assert_pkg_version(
-            "flashinfer",
-            "0.1.3",
-            "Please uninstall the old version and "
-            "reinstall the latest version by following the instructions "
-            "at https://docs.flashinfer.ai/installation.html.",
-        )
-
-    set_envs_and_config(server_args)
+    server_args.check_server_args()
+    _set_envs_and_config(server_args)

     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
@@ -284,7 +228,7 @@ def launch_server(
     )
     logger.info(f"{server_args=}")

-    # Handle multi-node tensor parallelism
+    # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
@@ -349,8 +293,9 @@ def launch_server(
         sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()

-    if server_args.api_key and server_args.api_key != "":
-        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)

     # Send a warmup request
     t = threading.Thread(
@@ -372,15 +317,58 @@ def launch_server(
     t.join()


+def _set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Check flashinfer version
+    if not server_args.disable_flashinfer:
+        assert_pkg_version(
+            "flashinfer",
+            "0.1.3",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
+
+
 def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
-        headers[API_KEY_HEADER_NAME] = server_args.api_key
+        headers["Authorization"] = f"Bearer {server_args.api_key}"

     # Wait until the server is launched
     for _ in range(120):
-        time.sleep(0.5)
+        time.sleep(1)
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
             break
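The new /v1/models route returns a ModelList with one ModelCard per served model, and the warmup request now authenticates with a Bearer header. A sketch of querying the route directly (assumed URL and key, mirroring the warmup and test code):

    import requests

    headers = {"Authorization": "Bearer sk-123456"}
    resp = requests.get("http://localhost:30000/v1/models", headers=headers)
    print(resp.status_code, resp.json())  # 200 and a list containing the served model name
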
@@ -61,7 +61,7 @@ class ServerArgs:
     show_time_cost: bool = False

     # Other
-    api_key: str = ""
+    api_key: Optional[str] = None
     file_storage_pth: str = "SGlang_storage"

     # Data parallelism
@@ -307,7 +307,7 @@ class ServerArgs:
             "--api-key",
             type=str,
             default=ServerArgs.api_key,
-            help="Set API key of the server.",
+            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
         )
         parser.add_argument(
             "--file-storage-pth",
@@ -539,26 +539,6 @@ class CustomCacheManager(FileCacheManager):
             raise RuntimeError("Could not create or locate cache dir")


-API_KEY_HEADER_NAME = "X-API-Key"
-
-
-class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
-
-    async def dispatch(self, request, call_next):
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
-
-
 def get_ip_address(ifname):
     """
     Get the IP address of a network interface.
@@ -642,6 +622,19 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()


+def set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 256
+
+
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -700,3 +693,15 @@ def monkey_patch_vllm_qvk_linear_loader():
         origin_weight_loader(self, param, loaded_weight, loaded_shard_id)

     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+
+
+def add_api_key_middleware(app, api_key):
+    @app.middleware("http")
+    async def authentication(request, call_next):
+        if request.method == "OPTIONS":
+            return await call_next(request)
+        if request.url.path.startswith("/health"):
+            return await call_next(request)
+        if request.headers.get("Authorization") != "Bearer " + api_key:
+            return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+        return await call_next(request)
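add_api_key_middleware replaces the removed APIKeyValidatorMiddleware: it registers a plain FastAPI HTTP middleware that lets OPTIONS and /health requests through and rejects everything else unless the Authorization header matches "Bearer <api_key>". A small self-contained sketch of that behavior (the demo app, routes, and key are assumptions for illustration, not part of the diff):

    from fastapi import FastAPI
    from fastapi.testclient import TestClient
    from sglang.srt.utils import add_api_key_middleware

    app = FastAPI()

    @app.get("/health")
    def health():
        return {"status": "ok"}

    @app.get("/v1/models")
    def models():
        return {"data": []}

    add_api_key_middleware(app, "sk-123456")
    client = TestClient(app)

    assert client.get("/health").status_code == 200      # exempt path
    assert client.get("/v1/models").status_code == 401   # missing or wrong key
    assert client.get(
        "/v1/models", headers={"Authorization": "Bearer sk-123456"}
    ).status_code == 200                                  # correct Bearer key
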
@@ -391,7 +391,11 @@ def get_call_select(args: argparse.Namespace):


 def popen_launch_server(
-    model: str, base_url: str, timeout: float, other_args: tuple = ()
+    model: str,
+    base_url: str,
+    timeout: float,
+    api_key: Optional[str] = None,
+    other_args: tuple = (),
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -408,12 +412,19 @@ def popen_launch_server(
         port,
         *other_args,
     ]
+    if api_key:
+        command += ["--api-key", api_key]
+
     process = subprocess.Popen(command, stdout=None, stderr=None)

     start_time = time.time()
     while time.time() - start_time < timeout:
         try:
-            response = requests.get(f"{base_url}/v1/models")
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {api_key}",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
             if response.status_code == 200:
                 return process
         except requests.RequestException:
@@ -76,19 +76,13 @@ class HttpResponse:
         return self.resp.status


-def http_request(
-    url, json=None, stream=False, auth_token=None, api_key=None, verify=None
-):
+def http_request(url, json=None, stream=False, api_key=None, verify=None):
     """A faster version of requests.post with low-level urllib API."""
     headers = {"Content-Type": "application/json; charset=utf-8"}

-    # add the Authorization header if an auth token is provided
-    if auth_token is not None:
-        headers["Authorization"] = f"Bearer {auth_token}"
-
-    # add the API Key header if an API key is provided
+    # add the Authorization header if an api key is provided
     if api_key is not None:
-        headers["X-API-Key"] = api_key
+        headers["Authorization"] = f"Bearer {api_key}"

     if stream:
         return requests.post(url, json=json, stream=True, headers=headers)
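http_request now maps api_key straight to the standard Authorization header instead of the old X-API-Key header, so the low-level helper and the OpenAI-compatible routes share one auth scheme. For example (assumed local URL and key):

    from sglang.utils import http_request

    # sends "Authorization: Bearer sk-123456" along with the JSON content type
    res = http_request("http://localhost:30000/get_model_info", api_key="sk-123456")
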
@@ -13,7 +13,10 @@ class TestOpenAIServer(unittest.TestCase):
     def setUpClass(cls):
         cls.model = MODEL_NAME_FOR_TEST
         cls.base_url = f"http://localhost:30000"
-        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
+        )
         cls.base_url += "/v1"

     @classmethod
@@ -21,7 +24,7 @@ class TestOpenAIServer(unittest.TestCase):
         kill_child_process(cls.process.pid)

     def run_completion(self, echo, logprobs, use_list_input):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"

         if use_list_input:
@@ -63,7 +66,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.total_tokens > 0

     def run_completion_stream(self, echo, logprobs):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
         generator = client.completions.create(
             model=self.model,
@@ -102,7 +105,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.total_tokens > 0

     def run_chat_completion(self, logprobs):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
             messages=[
@@ -135,7 +138,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.total_tokens > 0

     def run_chat_completion_stream(self, logprobs):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
             messages=[
@@ -186,7 +189,7 @@ class TestOpenAIServer(unittest.TestCase):
         self.run_chat_completion_stream(logprobs)

     def test_regex(self):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

         regex = (
             r"""\{\n"""