Make API Key OpenAI-compatible (#917)

Ying Sheng
2024-08-04 13:35:44 -07:00
committed by GitHub
parent afd411d09f
commit 0d4f3a9fcd
7 changed files with 115 additions and 125 deletions
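
Note: with this change the server authenticates requests using the standard
OpenAI "Authorization: Bearer <key>" header instead of a custom API-key
header, and a /v1/models endpoint is exposed. As an illustration (the port,
model name, and key below are placeholders, not taken from this commit), the
official OpenAI Python client can now talk to the server directly:

# Illustrative only: the OpenAI client sends the standard
# "Authorization: Bearer <key>" header that the new middleware checks.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:30000/v1",  # assumed local SGLang address
    api_key="your-api-key",  # must match the server's --api-key
)
print(client.models.list())  # served by the new /v1/models endpoint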


@@ -67,13 +67,13 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware,
+    add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
+    set_torch_compile_config,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -158,6 +158,16 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    served_model_names = [tokenizer_manager.served_model_name]
+    model_cards = []
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+    return ModelList(data=model_cards)
+
+
 @app.post("/v1/files")
 async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
     return await v1_files_create(
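
Note: the /v1/models handler added above reports the served model name next
to the other OpenAI-compatible routes. A raw-HTTP sketch of the same call
(the address and key are assumptions, not part of the diff):

# Sketch: listing models over plain HTTP with the OpenAI-style header.
import requests

resp = requests.get(
    "http://localhost:30000/v1/models",  # assumed server address
    headers={"Authorization": "Bearer your-api-key"},
)
resp.raise_for_status()
for card in resp.json()["data"]:  # ModelList serializes its cards under "data"
    print(card["id"])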
@@ -187,69 +197,11 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)
 
 
-@app.get("/v1/models")
-def available_models():
-    """Show available models."""
-    served_model_names = [tokenizer_manager.served_model_name]
-    model_cards = []
-    for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
-    return ModelList(data=model_cards)
-
-
-def _set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
-def set_envs_and_config(server_args: ServerArgs):
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-    # Set ulimit
-    set_ulimit()
-
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    server_args.check_server_args()
-
     """Launch an HTTP server."""
     global tokenizer_manager
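
Note: launch_server keeps its signature; only the validation call moves below
the logging setup (next hunk). A minimal programmatic launch might look like
this sketch (the model path and key are illustrative placeholders, and the
exact ServerArgs fields are assumptions beyond what this diff shows):

# Hypothetical embedding of the server in a script; mirrors what
# `python -m sglang.launch_server --model-path ... --api-key ...` sets up.
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    api_key="your-api-key",  # enables add_api_key_middleware below
)
launch_server(server_args)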
@@ -258,16 +210,8 @@ def launch_server(
         format="%(message)s",
     )
 
-    if not server_args.disable_flashinfer:
-        assert_pkg_version(
-            "flashinfer",
-            "0.1.3",
-            "Please uninstall the old version and "
-            "reinstall the latest version by following the instructions "
-            "at https://docs.flashinfer.ai/installation.html.",
-        )
-
-    set_envs_and_config(server_args)
+    server_args.check_server_args()
+    _set_envs_and_config(server_args)
 
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
@@ -284,7 +228,7 @@
     )
     logger.info(f"{server_args=}")
 
-    # Handle multi-node tensor parallelism
+    # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
@@ -349,8 +293,9 @@
         sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()
 
-    if server_args.api_key and server_args.api_key != "":
-        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)
 
     # Send a warmup request
     t = threading.Thread(
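
Note: add_api_key_middleware itself lives in sglang.srt.utils and is outside
this diff; a plausible sketch of such a helper, rejecting any request whose
Authorization header does not carry the configured Bearer token:

# Not the actual sglang.srt.utils implementation; a plausible sketch only.
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

def add_api_key_middleware(app: FastAPI, api_key: str):
    @app.middleware("http")
    async def authentication(request: Request, call_next):
        if request.headers.get("Authorization") != f"Bearer {api_key}":
            return JSONResponse({"error": "Unauthorized"}, status_code=401)
        return await call_next(request)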
@@ -372,15 +317,58 @@
     t.join()
 
 
+def _set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Check flashinfer version
+    if not server_args.disable_flashinfer:
+        assert_pkg_version(
+            "flashinfer",
+            "0.1.3",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
+
+
 def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
-        headers[API_KEY_HEADER_NAME] = server_args.api_key
+        headers["Authorization"] = f"Bearer {server_args.api_key}"
 
     # Wait until the server is launched
     for _ in range(120):
-        time.sleep(0.5)
+        time.sleep(1)
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
             break
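
Note: the warmup probe now authenticates itself with the same Bearer header
and polls once per second, up to 120 attempts. A standalone sketch of that
readiness loop (the endpoint name comes from the hunk; the wrapper function
is illustrative):

# Standalone version of the readiness poll above, for illustration.
import time
import requests

def wait_until_ready(url, api_key=None, attempts=120):
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    for _ in range(attempts):
        time.sleep(1)
        try:
            requests.get(url + "/get_model_info", timeout=5, headers=headers)
            return True
        except requests.exceptions.RequestException:
            continue
    return False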