diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 9d63a2aed..40db38b9a 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -2,10 +2,11 @@ import argparse
 
 from sglang.srt.server import ServerArgs, launch_server
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
-    launch_server(server_args, None)
+    launch_server(server_args, None)
\ No newline at end of file
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 55bd9e80c..f837d9029 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -37,6 +37,7 @@ from sglang.srt.utils import (
 )
 
 logger = logging.getLogger("model_rpc")
+logging.getLogger("vllm.utils").setLevel(logging.WARN)
 
 
 class ModelRpcServer:
@@ -113,7 +114,7 @@ class ModelRpcServer:
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
         )
-        logger.info(server_args.get_optional_modes_logging())
+        logger.info(f"server_args: {server_args.print_mode_args()}")
 
         # Init cache
         self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index b2a0daf5b..8d6851caf 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -28,7 +28,6 @@ QUANTIZATION_CONFIG_MAPPING = {
 
 logger = logging.getLogger("model_runner")
 
-
 # for server args in model endpoints
 global_server_args_dict: dict = None
 
@@ -276,9 +275,6 @@ class ModelRunner:
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
 
-        # A small all_reduce for warmup.
-        if self.tp_size > 1:
-            torch.distributed.all_reduce(torch.zeros(1).cuda())
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
 
         total_gpu_memory = get_available_gpu_memory(
diff --git a/python/sglang/srt/managers/openai_protocol.py b/python/sglang/srt/openai_protocol.py
similarity index 100%
rename from python/sglang/srt/managers/openai_protocol.py
rename to python/sglang/srt/openai_protocol.py
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 4ad5701c5..c1b7780cc 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -15,15 +15,11 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
 import aiohttp
 import psutil
-import pydantic
 import requests
 import uvicorn
 import uvloop
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import Response, StreamingResponse
-from pydantic import BaseModel
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.responses import JSONResponse
 
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
@@ -37,7 +33,7 @@ from sglang.srt.conversation import (
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
-from sglang.srt.managers.openai_protocol import (
+from sglang.srt.openai_protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
@@ -56,45 +52,24 @@ from sglang.srt.managers.openai_protocol import (
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import enable_show_time_cost, handle_port_init
+from sglang.srt.utils import (
+    enable_show_time_cost,
+    allocate_init_ports,
+    jsonify_pydantic_model,
+    assert_pkg_version,
+    get_exception_traceback,
+    API_KEY_HEADER_NAME,
+    APIKeyValidatorMiddleware,
+)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
-API_KEY_HEADER_NAME = "X-API-Key"
-
-
-class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
-
-    async def dispatch(self, request: Request, call_next):
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
-
 
 app = FastAPI()
 tokenizer_manager = None
 chat_template_name = None
 
-# FIXME: Remove this once we drop support for pydantic 1.x
-IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
-
-
-def jsonify_pydantic_model(obj: BaseModel):
-    if IS_PYDANTIC_1:
-        return obj.json(ensure_ascii=False)
-    return obj.model_dump_json()
-
-
 @app.get("/health")
 async def health() -> Response:
     """Health check."""
@@ -124,6 +99,31 @@ async def flush_cache():
     )
 
 
+async def stream_generator(obj: GenerateReqInput):
+    async for out in tokenizer_manager.generate_request(obj):
+        await handle_token_logprobs_results(obj, out)
+        yield out
+
+
+@app.post("/generate")
+async def generate_request(obj: GenerateReqInput):
+    obj.post_init()
+
+    if obj.stream:
+
+        async def stream_results():
+            async for out in stream_generator(obj):
+                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(stream_results(), media_type="text/event-stream")
+
+    ret = await tokenizer_manager.generate_request(obj).__anext__()
+    await handle_token_logprobs_results(obj, ret)
+
+    return ret
+
+
 async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
     if not decode_to_text:
         return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
@@ -175,68 +175,6 @@ async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
         await convert_style(r, obj.return_text_in_logprobs)
 
 
-async def stream_generator(obj: GenerateReqInput):
-    async for out in tokenizer_manager.generate_request(obj):
-        await handle_token_logprobs_results(obj, out)
-        yield out
-
-
-async def make_openai_style_logprobs(
-    prefill_token_logprobs=None,
-    decode_token_logprobs=None,
-    prefill_top_logprobs=None,
-    decode_top_logprobs=None,
-):
-    ret_logprobs = LogProbs()
-
-    def append_token_logprobs(token_logprobs):
-        for logprob, _, token_text in token_logprobs:
-            ret_logprobs.tokens.append(token_text)
-            ret_logprobs.token_logprobs.append(logprob)
-
-            # Not Supported yet
-            ret_logprobs.text_offset.append(-1)
-
-    def append_top_logprobs(top_logprobs):
-        for tokens in top_logprobs:
-            if tokens is not None:
-                ret_logprobs.top_logprobs.append(
-                    {token[2]: token[0] for token in tokens}
-                )
-            else:
-                ret_logprobs.top_logprobs.append(None)
-
-    if prefill_token_logprobs is not None:
-        append_token_logprobs(prefill_token_logprobs)
-    if decode_token_logprobs is not None:
-        append_token_logprobs(decode_token_logprobs)
-    if prefill_top_logprobs is not None:
-        append_top_logprobs(prefill_top_logprobs)
-    if decode_top_logprobs is not None:
-        append_top_logprobs(decode_top_logprobs)
-
-    return ret_logprobs
-
-
-@app.post("/generate")
-async def generate_request(obj: GenerateReqInput):
-    obj.post_init()
-
-    if obj.stream:
-
-        async def stream_results():
-            async for out in stream_generator(obj):
-                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
-            yield "data: [DONE]\n\n"
-
-        return StreamingResponse(stream_results(), media_type="text/event-stream")
-
-    ret = await tokenizer_manager.generate_request(obj).__anext__()
-    await handle_token_logprobs_results(obj, ret)
-
-    return ret
-
-
 @app.post("/v1/completions")
 async def v1_completions(raw_request: Request):
     request_json = await raw_request.json()
@@ -500,27 +438,97 @@ async def v1_chat_completions(raw_request: Request):
     return response
 
 
-def launch_server(server_args: ServerArgs, pipe_finish_writer):
-    global tokenizer_manager
+async def make_openai_style_logprobs(
+    prefill_token_logprobs=None,
+    decode_token_logprobs=None,
+    prefill_top_logprobs=None,
+    decode_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not Supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if prefill_token_logprobs is not None:
+        append_token_logprobs(prefill_token_logprobs)
+    if decode_token_logprobs is not None:
+        append_token_logprobs(decode_token_logprobs)
+    if prefill_top_logprobs is not None:
+        append_top_logprobs(prefill_top_logprobs)
+    if decode_top_logprobs is not None:
+        append_top_logprobs(decode_top_logprobs)
+
+    return ret_logprobs
+
+
+def load_chat_template_for_openai_api(chat_template_arg):
     global chat_template_name
 
-    if server_args.enable_flashinfer:
-        from sglang.srt.utils import assert_pkg_version
-        assert_pkg_version("flashinfer", "0.0.4")
+    print(f"Use chat template: {chat_template_arg}")
+    if not chat_template_exists(chat_template_arg):
+        if not os.path.exists(chat_template_arg):
+            raise RuntimeError(
+                f"Chat template {chat_template_arg} is not a built-in template name "
+                "or a valid chat template file path."
+            )
+        with open(chat_template_arg, "r") as filep:
+            template = json.load(filep)
+            try:
+                sep_style = SeparatorStyle[template["sep_style"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown separator style: {template['sep_style']}"
+                ) from None
+            register_conv_template(
+                Conversation(
+                    name=template["name"],
+                    system_template=template["system"] + "\n{system_message}",
+                    system_message=template.get("system_message", ""),
+                    roles=(template["user"], template["assistant"]),
+                    sep_style=sep_style,
+                    sep=template.get("sep", "\n"),
+                    stop_str=template["stop_str"],
+                ),
+                override=True,
+            )
+            chat_template_name = template["name"]
+    else:
+        chat_template_name = chat_template_arg
 
-    # start show time thread
+
+def launch_server(server_args: ServerArgs, pipe_finish_writer):
+    global tokenizer_manager
+
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     if server_args.show_time_cost:
         enable_show_time_cost()
-
-    # disable disk cache if needed
     if server_args.disable_disk_cache:
         disable_cache()
+    if server_args.enable_flashinfer:
+        assert_pkg_version("flashinfer", "0.0.4")
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
 
-    # Handle ports
-    server_args.port, server_args.additional_ports = handle_port_init(
+    # Allocate ports
+    server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port, server_args.additional_ports, server_args.tp_size
     )
-
     port_args = PortArgs(
         tokenizer_port=server_args.additional_ports[0],
         router_port=server_args.additional_ports[1],
@@ -529,39 +537,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
         model_rpc_ports=server_args.additional_ports[4:],
     )
 
-    # Load chat template if needed
-    if server_args.chat_template is not None:
-        print(f"Use chat template: {server_args.chat_template}")
-        if not chat_template_exists(server_args.chat_template):
-            if not os.path.exists(server_args.chat_template):
-                raise RuntimeError(
-                    f"Chat template {server_args.chat_template} is not a built-in template name "
-                    "or a valid chat template file path."
-                )
-            with open(server_args.chat_template, "r") as filep:
-                template = json.load(filep)
-                try:
-                    sep_style = SeparatorStyle[template["sep_style"]]
-                except KeyError:
-                    raise ValueError(
-                        f"Unknown separator style: {template['sep_style']}"
-                    ) from None
-                register_conv_template(
-                    Conversation(
-                        name=template["name"],
-                        system_template=template["system"] + "\n{system_message}",
-                        system_message=template.get("system_message", ""),
-                        roles=(template["user"], template["assistant"]),
-                        sep_style=sep_style,
-                        sep=template.get("sep", "\n"),
-                        stop_str=template["stop_str"],
-                    ),
-                    override=True,
-                )
-                chat_template_name = template["name"]
-        else:
-            chat_template_name = server_args.chat_template
-
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
@@ -593,31 +568,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
     if router_init_state != "init ok" or detoken_init_state != "init ok":
         proc_router.kill()
         proc_detoken.kill()
-        print("router init state:", router_init_state)
-        print("detoken init state:", detoken_init_state)
+        print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
+        print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
         sys.exit(1)
-
     assert proc_router.is_alive() and proc_detoken.is_alive()
 
     if server_args.api_key and server_args.api_key != "":
         app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
 
-    def _launch_server():
-        uvicorn.run(
-            app,
-            host=server_args.host,
-            port=server_args.port,
-            log_level=server_args.log_level,
-            timeout_keep_alive=5,
-            loop="uvloop",
-        )
-
     def _wait_and_warmup():
         headers = {}
         url = server_args.url()
-        if server_args.api_key and server_args.api_key != "":
+        if server_args.api_key:
            headers[API_KEY_HEADER_NAME] = server_args.api_key
 
+        # Wait until the server is launched
         for _ in range(120):
             time.sleep(0.5)
             try:
@@ -625,16 +590,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                 break
             except requests.exceptions.RequestException as e:
                 pass
-        else:
-            if pipe_finish_writer is not None:
-                pipe_finish_writer.send(str(e))
-            else:
-                print(e, flush=True)
-            return
 
-        # Warmup
+        # Send a warmup request
         try:
-            # print("Warmup...", flush=True)
             res = requests.post(
                 url + "/generate",
                 json={
@@ -647,14 +605,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                 headers=headers,
                 timeout=60,
             )
-            # print(f"Warmup done. model response: {res.json()['text']}")
-            # print("=" * 20, "Server is ready", "=" * 20, flush=True)
-        except requests.exceptions.RequestException as e:
+            assert res.status_code == 200
+        except Exception as e:
             if pipe_finish_writer is not None:
-                pipe_finish_writer.send(str(e))
-            else:
-                print(e, flush=True)
-            return
+                pipe_finish_writer.send(get_exception_traceback())
+            print(f"Initialization failed. warmup error: {e}")
+            raise e
 
         if pipe_finish_writer is not None:
             pipe_finish_writer.send("init ok")
@@ -662,7 +618,14 @@
     t = threading.Thread(target=_wait_and_warmup)
     t.start()
     try:
-        _launch_server()
+        uvicorn.run(
+            app,
+            host=server_args.host,
+            port=server_args.port,
+            log_level=server_args.log_level,
+            timeout_keep_alive=5,
+            loop="uvloop",
+        )
     finally:
         t.join()
 
@@ -670,52 +633,16 @@
 class Runtime:
     def __init__(
         self,
-        model_path: str,
-        tokenizer_path: Optional[str] = None,
-        load_format: str = "auto",
-        tokenizer_mode: str = "auto",
-        trust_remote_code: bool = True,
-        mem_fraction_static: float = ServerArgs.mem_fraction_static,
-        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
-        context_length: int = ServerArgs.context_length,
-        tp_size: int = 1,
-        schedule_heuristic: str = "lpm",
-        attention_reduce_in_fp32: bool = False,
-        random_seed: int = 42,
-        log_level: str = "error",
-        disable_radix_cache: bool = False,
-        enable_flashinfer: bool = False,
-        disable_regex_jump_forward: bool = False,
-        disable_disk_cache: bool = False,
-        api_key: str = "",
-        port: Optional[int] = None,
-        additional_ports: Optional[Union[List[int], int]] = None,
+        log_level="error",
+        *args,
+        **kwargs,
     ):
-        host = "127.0.0.1"
-        port, additional_ports = handle_port_init(port, additional_ports, tp_size)
-        self.server_args = ServerArgs(
-            model_path=model_path,
-            tokenizer_path=tokenizer_path,
-            host=host,
-            port=port,
-            additional_ports=additional_ports,
-            load_format=load_format,
-            tokenizer_mode=tokenizer_mode,
-            trust_remote_code=trust_remote_code,
-            mem_fraction_static=mem_fraction_static,
-            max_prefill_num_token=max_prefill_num_token,
-            context_length=context_length,
-            tp_size=tp_size,
-            schedule_heuristic=schedule_heuristic,
-            attention_reduce_in_fp32=attention_reduce_in_fp32,
-            random_seed=random_seed,
-            log_level=log_level,
-            disable_radix_cache=disable_radix_cache,
-            enable_flashinfer=enable_flashinfer,
-            disable_regex_jump_forward=disable_regex_jump_forward,
-            disable_disk_cache=disable_disk_cache,
-            api_key=api_key,
-        )
+        """See the arguments in server_args.py::ServerArgs"""
+        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+
+        # Pre-allocate ports
+        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
+            self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
 
         self.url = self.server_args.url()
         self.generate_url = (
@@ -736,7 +663,7 @@ class Runtime:
 
         if init_state != "init ok":
             self.shutdown()
-            raise RuntimeError("Launch failed. Please see the error messages above.")
+            raise RuntimeError("Initialization failed. Please see the error messages above.")
 
         self.endpoint = RuntimeEndpoint(self.url)
 
@@ -765,13 +692,12 @@ class Runtime:
         self,
         prompt: str,
         sampling_params,
-    ) -> None:
+    ):
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,
             "stream": True,
         }
-        pos = 0
 
         timeout = aiohttp.ClientTimeout(total=3 * 3600)
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 4ed14a6be..78b8537c4 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1,3 +1,5 @@
+"""The arguments of the server."""
+
 import argparse
 import dataclasses
 from typing import List, Optional, Union
@@ -5,33 +7,44 @@ from typing import List, Optional, Union
 
 @dataclasses.dataclass
 class ServerArgs:
+    # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    host: str = "127.0.0.1"
-    port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None
     load_format: str = "auto"
     tokenizer_mode: str = "auto"
     chat_template: Optional[str] = None
     trust_remote_code: bool = True
+    context_length: Optional[int] = None
+
+    # Port
+    host: str = "127.0.0.1"
+    port: int = 30000
+    additional_ports: Optional[Union[List[int], int]] = None
+
+    # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_prefill_num_token: Optional[int] = None
-    context_length: Optional[int] = None
-    tp_size: int = 1
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
-    attention_reduce_in_fp32: bool = False
-    random_seed: int = 42
+
+    # Other runtime options
+    tp_size: int = 1
     stream_interval: int = 8
+    random_seed: int = 42
+
+    # Logging
+    log_level: str = "info"
     disable_log_stats: bool = False
     log_stats_interval: int = 10
-    log_level: str = "info"
-    api_key: str = ""
     show_time_cost: bool = False
 
-    # optional modes
-    disable_radix_cache: bool = False
+    # Other
+    api_key: str = ""
+
+    # Optimization/debug options
     enable_flashinfer: bool = False
+    attention_reduce_in_fp32: bool = False
+    disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
 
@@ -66,15 +79,16 @@ class ServerArgs:
             default=ServerArgs.tokenizer_path,
             help="The path of the tokenizer.",
         )
-        parser.add_argument("--host", type=str, default=ServerArgs.host)
-        parser.add_argument("--port", type=int, default=ServerArgs.port)
-        # we want to be able to pass a list of ports
+        parser.add_argument("--host", type=str, default=ServerArgs.host,
+                            help="The host of the server.")
+        parser.add_argument("--port", type=int, default=ServerArgs.port,
+                            help="The port of the server.")
         parser.add_argument(
             "--additional-ports",
             type=int,
             nargs="*",
             default=[],
-            help="Additional ports specified for launching server.",
+            help="Additional ports specified for the server.",
         )
         parser.add_argument(
             "--load-format",
@@ -112,6 +126,12 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -124,18 +144,6 @@ class ServerArgs:
             default=ServerArgs.max_prefill_num_token,
             help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
         )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
-        )
-        parser.add_argument(
-            "--tp-size",
-            type=int,
-            default=ServerArgs.tp_size,
-            help="Tensor parallelism degree.",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -149,15 +157,10 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
-            "--random-seed",
+            "--tp-size",
             type=int,
-            default=ServerArgs.random_seed,
-            help="Random seed.",
-        )
-        parser.add_argument(
-            "--attention-reduce-in-fp32",
-            action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
+            default=ServerArgs.tp_size,
+            help="Tensor parallelism size.",
         )
         parser.add_argument(
             "--stream-interval",
@@ -165,11 +168,17 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--random-seed",
+            type=int,
+            default=ServerArgs.random_seed,
+            help="Random seed.",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
             default=ServerArgs.log_level,
-            help="Log level",
+            help="Logging level",
         )
         parser.add_argument(
             "--disable-log-stats",
@@ -182,29 +191,34 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
-        parser.add_argument(
-            "--api-key",
-            type=str,
-            default=ServerArgs.api_key,
-            help="Set API Key",
-        )
         parser.add_argument(
             "--show-time-cost",
             action="store_true",
             help="Show time cost of custom marks",
         )
-
-        # optional modes
         parser.add_argument(
-            "--disable-radix-cache",
-            action="store_true",
-            help="Disable RadixAttention",
+            "--api-key",
+            type=str,
+            default=ServerArgs.api_key,
+            help="Set API key of the server",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--enable-flashinfer",
             action="store_true",
             help="Enable flashinfer inference kernels",
         )
+        parser.add_argument(
+            "--attention-reduce-in-fp32",
+            action="store_true",
+            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16.",
+        )
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -224,13 +238,13 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
-    def get_optional_modes_logging(self):
+    def print_mode_args(self):
         return (
-            f"disable_radix_cache={self.disable_radix_cache}, "
             f"enable_flashinfer={self.enable_flashinfer}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
+            f"disable_radix_cache={self.disable_radix_cache}, "
             f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
             f"disable_disk_cache={self.disable_disk_cache}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
         )
 
 
@@ -240,4 +254,4 @@ class PortArgs:
     router_port: int
     detokenizer_port: int
     nccl_port: int
-    model_rpc_ports: List[int]
+    model_rpc_ports: List[int]
\ No newline at end of file
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 6b2c258d1..774bdf2c9 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -10,9 +10,12 @@ from io import BytesIO
 from typing import List, Optional
 
 import numpy as np
+import pydantic
 import requests
 import torch
 from packaging import version as pkg_version
+from pydantic import BaseModel
+from starlette.middleware.base import BaseHTTPMiddleware
 
 show_time_cost = False
 time_infos = {}
@@ -120,7 +123,7 @@ def check_port(port):
         return False
 
 
-def handle_port_init(
+def allocate_init_ports(
     port: Optional[int] = None,
     additional_ports: Optional[List[int]] = None,
     tp_size: int = 1,
@@ -159,8 +162,6 @@ def get_exception_traceback():
 
 
 def get_int_token_logit_bias(tokenizer, vocab_size):
-    from transformers import LlamaTokenizer, LlamaTokenizerFast
-
     # a bug when model's vocab size > tokenizer.vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -281,3 +282,32 @@ def assert_pkg_version(pkg: str, min_version: str):
         )
     except PackageNotFoundError:
         raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
+
+
+API_KEY_HEADER_NAME = "X-API-Key"
+
+
+class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
+    def __init__(self, app, api_key: str):
+        super().__init__(app)
+        self.api_key = api_key
+
+    async def dispatch(self, request, call_next):
+        # extract API key from the request headers
+        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
+        if not api_key_header or api_key_header != self.api_key:
+            return JSONResponse(
+                status_code=403,
+                content={"detail": "Invalid API Key"},
+            )
+        response = await call_next(request)
+        return response
+
+# FIXME: Remove this once we drop support for pydantic 1.x
+IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
+
+
+def jsonify_pydantic_model(obj: BaseModel):
+    if IS_PYDANTIC_1:
+        return obj.json(ensure_ascii=False)
+    return obj.model_dump_json()
diff --git a/test/killall_python.sh b/test/killall_python.sh
index ae9de8701..0e2cb82a8 100644
--- a/test/killall_python.sh
+++ b/test/killall_python.sh
@@ -1 +1 @@
-kill -9 $(ps aux | grep 'python' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'sglang' | grep -v 'grep' | awk '{print $2}')
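
Usage note (illustrative, not part of the patch): a minimal sketch of how the reorganized entry point above might be exercised after this change, assuming Runtime forwards keyword arguments to ServerArgs as shown in the hunks; the model path is a placeholder.

    # Hypothetical usage sketch; "<path-or-hf-model-id>" is a placeholder value.
    from sglang.srt.server import Runtime

    # Runtime now passes kwargs straight to ServerArgs and pre-allocates
    # ports via allocate_init_ports() before launching the server.
    runtime = Runtime(model_path="<path-or-hf-model-id>", tp_size=1)
    print(runtime.url)  # e.g. http://127.0.0.1:30000
    runtime.shutdown()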