From 0d4f3a9fcdea60ac327a6a5897a281a1d763c3ac Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Sun, 4 Aug 2024 13:35:44 -0700 Subject: [PATCH] Make API Key OpenAI-compatible (#917) --- .../sglang/lang/backend/runtime_endpoint.py | 11 -- python/sglang/srt/server.py | 138 ++++++++---------- python/sglang/srt/server_args.py | 4 +- python/sglang/srt/utils.py | 45 +++--- python/sglang/test/test_utils.py | 15 +- python/sglang/utils.py | 12 +- test/srt/test_openai_server.py | 15 +- 7 files changed, 115 insertions(+), 125 deletions(-) diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 929b3b6ad..7736c9f65 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -15,7 +15,6 @@ class RuntimeEndpoint(BaseBackend): def __init__( self, base_url: str, - auth_token: Optional[str] = None, api_key: Optional[str] = None, verify: Optional[str] = None, ): @@ -23,13 +22,11 @@ class RuntimeEndpoint(BaseBackend): self.support_concate_and_append = True self.base_url = base_url - self.auth_token = auth_token self.api_key = api_key self.verify = verify res = http_request( self.base_url + "/get_model_info", - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -67,7 +64,6 @@ class RuntimeEndpoint(BaseBackend): res = http_request( self.base_url + "/generate", json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -79,7 +75,6 @@ class RuntimeEndpoint(BaseBackend): res = http_request( self.base_url + "/generate", json=data, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -91,7 +86,6 @@ class RuntimeEndpoint(BaseBackend): res = http_request( self.base_url + "/generate", json=data, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -139,7 +133,6 @@ class RuntimeEndpoint(BaseBackend): res = http_request( self.base_url + "/generate", json=data, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -193,7 +186,6 @@ class RuntimeEndpoint(BaseBackend): self.base_url + "/generate", json=data, stream=True, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -225,7 +217,6 @@ class RuntimeEndpoint(BaseBackend): res = http_request( self.base_url + "/generate", json=data, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -243,7 +234,6 @@ class RuntimeEndpoint(BaseBackend): res = http_request( self.base_url + "/generate", json=data, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) @@ -267,7 +257,6 @@ class RuntimeEndpoint(BaseBackend): res = http_request( self.base_url + "/concate_and_append_request", json={"src_rids": src_rids, "dst_rid": dst_rid}, - auth_token=self.auth_token, api_key=self.api_key, verify=self.verify, ) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 18ff22432..4df1431cb 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -67,13 +67,13 @@ from sglang.srt.openai_api.adapter import ( from sglang.srt.openai_api.protocol import ModelCard, ModelList from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( - API_KEY_HEADER_NAME, - APIKeyValidatorMiddleware, + add_api_key_middleware, allocate_init_ports, assert_pkg_version, enable_show_time_cost, kill_child_process, maybe_set_triton_cache_manager, + set_torch_compile_config, set_ulimit, ) from sglang.utils import get_exception_traceback @@ -158,6 +158,16 @@ async def openai_v1_chat_completions(raw_request: Request): return await v1_chat_completions(tokenizer_manager, raw_request) +@app.get("/v1/models") +def available_models(): + """Show available models.""" + served_model_names = [tokenizer_manager.served_model_name] + model_cards = [] + for served_model_name in served_model_names: + model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) + return ModelList(data=model_cards) + + @app.post("/v1/files") async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): return await v1_files_create( @@ -187,69 +197,11 @@ async def retrieve_file_content(file_id: str): return await v1_retrieve_file_content(file_id) -@app.get("/v1/models") -def available_models(): - """Show available models.""" - served_model_names = [tokenizer_manager.served_model_name] - model_cards = [] - for served_model_name in served_model_names: - model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) - return ModelList(data=model_cards) - - -def _set_torch_compile_config(): - # The following configurations are for torch compile optimizations - import torch._dynamo.config - import torch._inductor.config - - torch._inductor.config.coordinate_descent_tuning = True - torch._inductor.config.triton.unique_kernel_names = True - torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future - - # FIXME: tmp workaround - torch._dynamo.config.accumulated_cache_size_limit = 256 - - -def set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - - # Set ulimit - set_ulimit() - - # Enable show time cost for debugging - if server_args.show_time_cost: - enable_show_time_cost() - - # Disable disk cache - if server_args.disable_disk_cache: - disable_cache() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Set torch compile config - if server_args.enable_torch_compile: - _set_torch_compile_config() - - # Set global chat template - if server_args.chat_template: - # TODO: replace this with huggingface transformers template - load_chat_template_for_openai_api(server_args.chat_template) - - def launch_server( server_args: ServerArgs, model_overide_args: Optional[dict] = None, pipe_finish_writer: Optional[mp.connection.Connection] = None, ): - server_args.check_server_args() - """Launch an HTTP server.""" global tokenizer_manager @@ -258,16 +210,8 @@ def launch_server( format="%(message)s", ) - if not server_args.disable_flashinfer: - assert_pkg_version( - "flashinfer", - "0.1.3", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - - set_envs_and_config(server_args) + server_args.check_server_args() + _set_envs_and_config(server_args) # Allocate ports server_args.port, server_args.additional_ports = allocate_init_ports( @@ -284,7 +228,7 @@ def launch_server( ) logger.info(f"{server_args=}") - # Handle multi-node tensor parallelism + # Launch processes for multi-node tensor parallelism if server_args.nnodes > 1: if server_args.node_rank != 0: tp_size_local = server_args.tp_size // server_args.nnodes @@ -349,8 +293,9 @@ def launch_server( sys.exit(1) assert proc_controller.is_alive() and proc_detoken.is_alive() - if server_args.api_key and server_args.api_key != "": - app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key) + # Add api key authorization + if server_args.api_key: + add_api_key_middleware(app, server_args.api_key) # Send a warmup request t = threading.Thread( @@ -372,15 +317,58 @@ def launch_server( t.join() +def _set_envs_and_config(server_args: ServerArgs): + # Set global environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # Set ulimit + set_ulimit() + + # Enable show time cost for debugging + if server_args.show_time_cost: + enable_show_time_cost() + + # Disable disk cache + if server_args.disable_disk_cache: + disable_cache() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. + maybe_set_triton_cache_manager() + + # Set torch compile config + if server_args.enable_torch_compile: + set_torch_compile_config() + + # Set global chat template + if server_args.chat_template: + # TODO: replace this with huggingface transformers template + load_chat_template_for_openai_api(server_args.chat_template) + + # Check flashinfer version + if not server_args.disable_flashinfer: + assert_pkg_version( + "flashinfer", + "0.1.3", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) + + def _wait_and_warmup(server_args, pipe_finish_writer): headers = {} url = server_args.url() if server_args.api_key: - headers[API_KEY_HEADER_NAME] = server_args.api_key + headers["Authorization"] = f"Bearer {server_args.api_key}" # Wait until the server is launched for _ in range(120): - time.sleep(0.5) + time.sleep(1) try: requests.get(url + "/get_model_info", timeout=5, headers=headers) break diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 53aaca977..32e13658e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -61,7 +61,7 @@ class ServerArgs: show_time_cost: bool = False # Other - api_key: str = "" + api_key: Optional[str] = None file_storage_pth: str = "SGlang_storage" # Data parallelism @@ -307,7 +307,7 @@ class ServerArgs: "--api-key", type=str, default=ServerArgs.api_key, - help="Set API key of the server.", + help="Set API key of the server. It is also used in the OpenAI API compatible server.", ) parser.add_argument( "--file-storage-pth", diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index fac8bdaa1..172c93c74 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -539,26 +539,6 @@ class CustomCacheManager(FileCacheManager): raise RuntimeError("Could not create or locate cache dir") -API_KEY_HEADER_NAME = "X-API-Key" - - -class APIKeyValidatorMiddleware(BaseHTTPMiddleware): - def __init__(self, app, api_key: str): - super().__init__(app) - self.api_key = api_key - - async def dispatch(self, request, call_next): - # extract API key from the request headers - api_key_header = request.headers.get(API_KEY_HEADER_NAME) - if not api_key_header or api_key_header != self.api_key: - return JSONResponse( - status_code=403, - content={"detail": "Invalid API Key"}, - ) - response = await call_next(request) - return response - - def get_ip_address(ifname): """ Get the IP address of a network interface. @@ -642,6 +622,19 @@ def receive_addrs(model_port_args, server_args): dist.destroy_process_group() +def set_torch_compile_config(): + # The following configurations are for torch compile optimizations + import torch._dynamo.config + import torch._inductor.config + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + + # FIXME: tmp workaround + torch._dynamo.config.accumulated_cache_size_limit = 256 + + def set_ulimit(target_soft_limit=65535): resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) @@ -700,3 +693,15 @@ def monkey_patch_vllm_qvk_linear_loader(): origin_weight_loader(self, param, loaded_weight, loaded_shard_id) setattr(QKVParallelLinear, "weight_loader", weight_loader_srt) + + +def add_api_key_middleware(app, api_key): + @app.middleware("http") + async def authentication(request, call_next): + if request.method == "OPTIONS": + return await call_next(request) + if request.url.path.startswith("/health"): + return await call_next(request) + if request.headers.get("Authorization") != "Bearer " + api_key: + return JSONResponse(content={"error": "Unauthorized"}, status_code=401) + return await call_next(request) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 2ab009eba..1fe237a2f 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -391,7 +391,11 @@ def get_call_select(args: argparse.Namespace): def popen_launch_server( - model: str, base_url: str, timeout: float, other_args: tuple = () + model: str, + base_url: str, + timeout: float, + api_key: Optional[str] = None, + other_args: tuple = (), ): _, host, port = base_url.split(":") host = host[2:] @@ -408,12 +412,19 @@ def popen_launch_server( port, *other_args, ] + if api_key: + command += ["--api-key", api_key] + process = subprocess.Popen(command, stdout=None, stderr=None) start_time = time.time() while time.time() - start_time < timeout: try: - response = requests.get(f"{base_url}/v1/models") + headers = { + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {api_key}", + } + response = requests.get(f"{base_url}/v1/models", headers=headers) if response.status_code == 200: return process except requests.RequestException: diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 27a8c40b8..c1193df3c 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -76,19 +76,13 @@ class HttpResponse: return self.resp.status -def http_request( - url, json=None, stream=False, auth_token=None, api_key=None, verify=None -): +def http_request(url, json=None, stream=False, api_key=None, verify=None): """A faster version of requests.post with low-level urllib API.""" headers = {"Content-Type": "application/json; charset=utf-8"} - # add the Authorization header if an auth token is provided - if auth_token is not None: - headers["Authorization"] = f"Bearer {auth_token}" - - # add the API Key header if an API key is provided + # add the Authorization header if an api key is provided if api_key is not None: - headers["X-API-Key"] = api_key + headers["Authorization"] = f"Bearer {api_key}" if stream: return requests.post(url, json=json, stream=True, headers=headers) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index a2b934b6b..269664e14 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -13,7 +13,10 @@ class TestOpenAIServer(unittest.TestCase): def setUpClass(cls): cls.model = MODEL_NAME_FOR_TEST cls.base_url = f"http://localhost:30000" - cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=300, api_key=cls.api_key + ) cls.base_url += "/v1" @classmethod @@ -21,7 +24,7 @@ class TestOpenAIServer(unittest.TestCase): kill_child_process(cls.process.pid) def run_completion(self, echo, logprobs, use_list_input): - client = openai.Client(api_key="EMPTY", base_url=self.base_url) + client = openai.Client(api_key=self.api_key, base_url=self.base_url) prompt = "The capital of France is" if use_list_input: @@ -63,7 +66,7 @@ class TestOpenAIServer(unittest.TestCase): assert response.usage.total_tokens > 0 def run_completion_stream(self, echo, logprobs): - client = openai.Client(api_key="EMPTY", base_url=self.base_url) + client = openai.Client(api_key=self.api_key, base_url=self.base_url) prompt = "The capital of France is" generator = client.completions.create( model=self.model, @@ -102,7 +105,7 @@ class TestOpenAIServer(unittest.TestCase): assert response.usage.total_tokens > 0 def run_chat_completion(self, logprobs): - client = openai.Client(api_key="EMPTY", base_url=self.base_url) + client = openai.Client(api_key=self.api_key, base_url=self.base_url) response = client.chat.completions.create( model=self.model, messages=[ @@ -135,7 +138,7 @@ class TestOpenAIServer(unittest.TestCase): assert response.usage.total_tokens > 0 def run_chat_completion_stream(self, logprobs): - client = openai.Client(api_key="EMPTY", base_url=self.base_url) + client = openai.Client(api_key=self.api_key, base_url=self.base_url) generator = client.chat.completions.create( model=self.model, messages=[ @@ -186,7 +189,7 @@ class TestOpenAIServer(unittest.TestCase): self.run_chat_completion_stream(logprobs) def test_regex(self): - client = openai.Client(api_key="EMPTY", base_url=self.base_url) + client = openai.Client(api_key=self.api_key, base_url=self.base_url) regex = ( r"""\{\n"""