Clean up (#422)
@@ -2,10 +2,11 @@ import argparse

from sglang.srt.server import ServerArgs, launch_server

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)

    launch_server(server_args, None)
    launch_server(server_args, None)
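Note (not part of this commit): the entry point above only parses ServerArgs from the command line and hands the result to launch_server. A minimal programmatic sketch of the same flow, with a placeholder model path:

from sglang.srt.server import ServerArgs, launch_server

# Build the arguments directly instead of via argparse; model_path is the
# only required ServerArgs field (see server_args.py later in this diff).
server_args = ServerArgs(model_path="path/to/model")
launch_server(server_args, None)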
@@ -37,6 +37,7 @@ from sglang.srt.utils import (
)

logger = logging.getLogger("model_rpc")
logging.getLogger("vllm.utils").setLevel(logging.WARN)


class ModelRpcServer:
@@ -113,7 +114,7 @@ class ModelRpcServer:
            f"max_prefill_num_token={self.max_prefill_num_token}, "
            f"context_len={self.model_config.context_len}, "
        )
        logger.info(server_args.get_optional_modes_logging())
        logger.info(f"server_args: {server_args.print_mode_args()}")

        # Init cache
        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
@@ -28,7 +28,6 @@ QUANTIZATION_CONFIG_MAPPING = {

logger = logging.getLogger("model_runner")


# for server args in model endpoints
global_server_args_dict: dict = None

@@ -276,9 +275,6 @@ class ModelRunner:
            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )

        # A small all_reduce for warmup.
        if self.tp_size > 1:
            torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)

        total_gpu_memory = get_available_gpu_memory(
@@ -15,15 +15,11 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import psutil
import pydantic
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
@@ -37,7 +33,7 @@ from sglang.srt.conversation import (
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
from sglang.srt.managers.openai_protocol import (
from sglang.srt.openai_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
@@ -56,45 +52,24 @@ from sglang.srt.managers.openai_protocol import (
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import enable_show_time_cost, handle_port_init
from sglang.srt.utils import (
    enable_show_time_cost,
    allocate_init_ports,
    jsonify_pydantic_model,
    assert_pkg_version,
    get_exception_traceback,
    API_KEY_HEADER_NAME,
    APIKeyValidatorMiddleware
)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

API_KEY_HEADER_NAME = "X-API-Key"


class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
        if not api_key_header or api_key_header != self.api_key:
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        response = await call_next(request)
        return response


app = FastAPI()
tokenizer_manager = None
chat_template_name = None


# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1


def jsonify_pydantic_model(obj: BaseModel):
    if IS_PYDANTIC_1:
        return obj.json(ensure_ascii=False)
    return obj.model_dump_json()


@app.get("/health")
async def health() -> Response:
    """Health check."""
@@ -124,6 +99,31 @@ async def flush_cache():
    )


async def stream_generator(obj: GenerateReqInput):
    async for out in tokenizer_manager.generate_request(obj):
        await handle_token_logprobs_results(obj, out)
        yield out


@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
    obj.post_init()

    if obj.stream:

        async def stream_results():
            async for out in stream_generator(obj):
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_results(), media_type="text/event-stream")

    ret = await tokenizer_manager.generate_request(obj).__anext__()
    await handle_token_logprobs_results(obj, ret)

    return ret

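Note (not part of this commit): a client-side sketch of consuming the /generate endpoint defined above. The request fields mirror the json_data dict used by Runtime.add_request later in this diff; the sampling parameter name max_new_tokens and the default host/port are assumptions.

import json
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16},  # assumed parameter name
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue
    decoded = line.decode()
    if decoded == "data: [DONE]":
        break
    if decoded.startswith("data: "):
        out = json.loads(decoded[len("data: "):])  # one incremental generation state
        print(out)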
async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
    if not decode_to_text:
        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
@@ -175,68 +175,6 @@ async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
            await convert_style(r, obj.return_text_in_logprobs)


async def stream_generator(obj: GenerateReqInput):
    async for out in tokenizer_manager.generate_request(obj):
        await handle_token_logprobs_results(obj, out)
        yield out


async def make_openai_style_logprobs(
    prefill_token_logprobs=None,
    decode_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
):
    ret_logprobs = LogProbs()

    def append_token_logprobs(token_logprobs):
        for logprob, _, token_text in token_logprobs:
            ret_logprobs.tokens.append(token_text)
            ret_logprobs.token_logprobs.append(logprob)

            # Not Supported yet
            ret_logprobs.text_offset.append(-1)

    def append_top_logprobs(top_logprobs):
        for tokens in top_logprobs:
            if tokens is not None:
                ret_logprobs.top_logprobs.append(
                    {token[2]: token[0] for token in tokens}
                )
            else:
                ret_logprobs.top_logprobs.append(None)

    if prefill_token_logprobs is not None:
        append_token_logprobs(prefill_token_logprobs)
    if decode_token_logprobs is not None:
        append_token_logprobs(decode_token_logprobs)
    if prefill_top_logprobs is not None:
        append_top_logprobs(prefill_top_logprobs)
    if decode_top_logprobs is not None:
        append_top_logprobs(decode_top_logprobs)

    return ret_logprobs


@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
    obj.post_init()

    if obj.stream:

        async def stream_results():
            async for out in stream_generator(obj):
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_results(), media_type="text/event-stream")

    ret = await tokenizer_manager.generate_request(obj).__anext__()
    await handle_token_logprobs_results(obj, ret)

    return ret


@app.post("/v1/completions")
async def v1_completions(raw_request: Request):
    request_json = await raw_request.json()
@@ -500,27 +438,97 @@ async def v1_chat_completions(raw_request: Request):
    return response


def launch_server(server_args: ServerArgs, pipe_finish_writer):
    global tokenizer_manager
async def make_openai_style_logprobs(
    prefill_token_logprobs=None,
    decode_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
):
    ret_logprobs = LogProbs()

    def append_token_logprobs(token_logprobs):
        for logprob, _, token_text in token_logprobs:
            ret_logprobs.tokens.append(token_text)
            ret_logprobs.token_logprobs.append(logprob)

            # Not Supported yet
            ret_logprobs.text_offset.append(-1)

    def append_top_logprobs(top_logprobs):
        for tokens in top_logprobs:
            if tokens is not None:
                ret_logprobs.top_logprobs.append(
                    {token[2]: token[0] for token in tokens}
                )
            else:
                ret_logprobs.top_logprobs.append(None)

    if prefill_token_logprobs is not None:
        append_token_logprobs(prefill_token_logprobs)
    if decode_token_logprobs is not None:
        append_token_logprobs(decode_token_logprobs)
    if prefill_top_logprobs is not None:
        append_top_logprobs(prefill_top_logprobs)
    if decode_top_logprobs is not None:
        append_top_logprobs(decode_top_logprobs)

    return ret_logprobs

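Note (not part of this commit): an illustrative call with made-up values, showing the tuple layout make_openai_style_logprobs consumes: (logprob, token_id, token_text) triples for per-token logprobs, and a per-position list of such triples (or None) for top logprobs.

import asyncio

decode_token_logprobs = [(-0.11, 464, "The"), (-1.32, 3139, " capital")]
decode_top_logprobs = [
    [(-0.11, 464, "The"), (-2.05, 317, "A")],
    None,  # positions without top logprobs map to None
]

logprobs = asyncio.run(
    make_openai_style_logprobs(
        decode_token_logprobs=decode_token_logprobs,
        decode_top_logprobs=decode_top_logprobs,
    )
)
# logprobs.tokens          -> ["The", " capital"]
# logprobs.token_logprobs  -> [-0.11, -1.32]
# logprobs.top_logprobs    -> [{"The": -0.11, "A": -2.05}, None]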
def load_chat_template_for_openai_api(chat_template_arg):
    global chat_template_name

    if server_args.enable_flashinfer:
        from sglang.srt.utils import assert_pkg_version
        assert_pkg_version("flashinfer", "0.0.4")
    print(f"Use chat template: {chat_template_arg}")
    if not chat_template_exists(chat_template_arg):
        if not os.path.exists(chat_template_arg):
            raise RuntimeError(
                f"Chat template {chat_template_arg} is not a built-in template name "
                "or a valid chat template file path."
            )
        with open(chat_template_arg, "r") as filep:
            template = json.load(filep)
            try:
                sep_style = SeparatorStyle[template["sep_style"]]
            except KeyError:
                raise ValueError(
                    f"Unknown separator style: {template['sep_style']}"
                ) from None
            register_conv_template(
                Conversation(
                    name=template["name"],
                    system_template=template["system"] + "\n{system_message}",
                    system_message=template.get("system_message", ""),
                    roles=(template["user"], template["assistant"]),
                    sep_style=sep_style,
                    sep=template.get("sep", "\n"),
                    stop_str=template["stop_str"],
                ),
                override=True,
            )
            chat_template_name = template["name"]
    else:
        chat_template_name = chat_template_arg

    # start show time thread

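Note (not part of this commit): an illustrative, hypothetical template file for load_chat_template_for_openai_api above. The keys are the ones the loader reads; the sep_style value must name a member of SeparatorStyle (ADD_COLON_SINGLE is an assumed member name).

import json

my_template = {
    "name": "my-template",
    "system": "SYSTEM:",
    "system_message": "You are a helpful assistant.",
    "user": "USER:",
    "assistant": "ASSISTANT:",
    "sep_style": "ADD_COLON_SINGLE",
    "sep": "\n",
    "stop_str": "USER:",
}
with open("my_template.json", "w") as f:
    json.dump(my_template, f)

# Then pass --chat-template my_template.json when launching the server,
# or a built-in template name recognized by chat_template_exists.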
def launch_server(server_args: ServerArgs, pipe_finish_writer):
    global tokenizer_manager

    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    if server_args.show_time_cost:
        enable_show_time_cost()

    # disable disk cache if needed
    if server_args.disable_disk_cache:
        disable_cache()
    if server_args.enable_flashinfer:
        assert_pkg_version("flashinfer", "0.0.4")
    if server_args.chat_template:
        # TODO: replace this with huggingface transformers template
        load_chat_template_for_openai_api(server_args.chat_template)

    # Handle ports
    server_args.port, server_args.additional_ports = handle_port_init(
    # Allocate ports
    server_args.port, server_args.additional_ports = allocate_init_ports(
        server_args.port, server_args.additional_ports, server_args.tp_size
    )

    port_args = PortArgs(
        tokenizer_port=server_args.additional_ports[0],
        router_port=server_args.additional_ports[1],
@@ -529,39 +537,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
        model_rpc_ports=server_args.additional_ports[4:],
    )

    # Load chat template if needed
    if server_args.chat_template is not None:
        print(f"Use chat template: {server_args.chat_template}")
        if not chat_template_exists(server_args.chat_template):
            if not os.path.exists(server_args.chat_template):
                raise RuntimeError(
                    f"Chat template {server_args.chat_template} is not a built-in template name "
                    "or a valid chat template file path."
                )
            with open(server_args.chat_template, "r") as filep:
                template = json.load(filep)
                try:
                    sep_style = SeparatorStyle[template["sep_style"]]
                except KeyError:
                    raise ValueError(
                        f"Unknown separator style: {template['sep_style']}"
                    ) from None
                register_conv_template(
                    Conversation(
                        name=template["name"],
                        system_template=template["system"] + "\n{system_message}",
                        system_message=template.get("system_message", ""),
                        roles=(template["user"], template["assistant"]),
                        sep_style=sep_style,
                        sep=template.get("sep", "\n"),
                        stop_str=template["stop_str"],
                    ),
                    override=True,
                )
                chat_template_name = template["name"]
        else:
            chat_template_name = server_args.chat_template

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
@@ -593,31 +568,21 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
    if router_init_state != "init ok" or detoken_init_state != "init ok":
        proc_router.kill()
        proc_detoken.kill()
        print("router init state:", router_init_state)
        print("detoken init state:", detoken_init_state)
        print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
        print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
        sys.exit(1)

    assert proc_router.is_alive() and proc_detoken.is_alive()

    if server_args.api_key and server_args.api_key != "":
        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

    def _launch_server():
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )

    def _wait_and_warmup():
        headers = {}
        url = server_args.url()
        if server_args.api_key and server_args.api_key != "":
        if server_args.api_key:
            headers[API_KEY_HEADER_NAME] = server_args.api_key

        # Wait until the server is launched
        for _ in range(120):
            time.sleep(0.5)
            try:
@@ -625,16 +590,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                break
            except requests.exceptions.RequestException as e:
                pass
        else:
            if pipe_finish_writer is not None:
                pipe_finish_writer.send(str(e))
            else:
                print(e, flush=True)
            return

        # Warmup
        # Send a warmup request
        try:
            # print("Warmup...", flush=True)
            res = requests.post(
                url + "/generate",
                json={
@@ -647,14 +605,12 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
                headers=headers,
                timeout=60,
            )
            # print(f"Warmup done. model response: {res.json()['text']}")
            # print("=" * 20, "Server is ready", "=" * 20, flush=True)
        except requests.exceptions.RequestException as e:
            assert res.status_code == 200
        except Exception as e:
            if pipe_finish_writer is not None:
                pipe_finish_writer.send(str(e))
            else:
                print(e, flush=True)
            return
            pipe_finish_writer.send(get_exception_traceback())
            print(f"Initialization failed. warmup error: {e}")
            raise e

        if pipe_finish_writer is not None:
            pipe_finish_writer.send("init ok")
@@ -662,7 +618,14 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
    t = threading.Thread(target=_wait_and_warmup)
    t.start()
    try:
        _launch_server()
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()

@@ -670,52 +633,16 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
class Runtime:
    def __init__(
        self,
        model_path: str,
        tokenizer_path: Optional[str] = None,
        load_format: str = "auto",
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        mem_fraction_static: float = ServerArgs.mem_fraction_static,
        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
        context_length: int = ServerArgs.context_length,
        tp_size: int = 1,
        schedule_heuristic: str = "lpm",
        attention_reduce_in_fp32: bool = False,
        random_seed: int = 42,
        log_level: str = "error",
        disable_radix_cache: bool = False,
        enable_flashinfer: bool = False,
        disable_regex_jump_forward: bool = False,
        disable_disk_cache: bool = False,
        api_key: str = "",
        port: Optional[int] = None,
        additional_ports: Optional[Union[List[int], int]] = None,
        log_evel="error",
        *args,
        **kwargs,
    ):
        host = "127.0.0.1"
        port, additional_ports = handle_port_init(port, additional_ports, tp_size)
        self.server_args = ServerArgs(
            model_path=model_path,
            tokenizer_path=tokenizer_path,
            host=host,
            port=port,
            additional_ports=additional_ports,
            load_format=load_format,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            mem_fraction_static=mem_fraction_static,
            max_prefill_num_token=max_prefill_num_token,
            context_length=context_length,
            tp_size=tp_size,
            schedule_heuristic=schedule_heuristic,
            attention_reduce_in_fp32=attention_reduce_in_fp32,
            random_seed=random_seed,
            log_level=log_level,
            disable_radix_cache=disable_radix_cache,
            enable_flashinfer=enable_flashinfer,
            disable_regex_jump_forward=disable_regex_jump_forward,
            disable_disk_cache=disable_disk_cache,
            api_key=api_key,
        )
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_evel, **kwargs)

        # Pre-allocate ports
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
            self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)

        self.url = self.server_args.url()
        self.generate_url = (
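Note (not part of this commit): after this refactor, Runtime forwards *args/**kwargs straight to ServerArgs, so any dataclass field can be passed by name (the log_evel keyword, sic, sets the default log level to "error"). A minimal usage sketch with a placeholder model path:

from sglang.srt.server import Runtime

runtime = Runtime(model_path="path/to/model", tp_size=1)
print(runtime.url)  # e.g. http://127.0.0.1:30000
runtime.shutdown()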
@@ -736,7 +663,7 @@ class Runtime:

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError("Launch failed. Please see the error messages above.")
            raise RuntimeError("Initialization failed. Please see the error messages above.")

        self.endpoint = RuntimeEndpoint(self.url)

@@ -765,13 +692,12 @@ class Runtime:
        self,
        prompt: str,
        sampling_params,
    ) -> None:
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "stream": True,
        }

        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)

@@ -1,3 +1,5 @@
"""The arguments of the server."""

import argparse
import dataclasses
from typing import List, Optional, Union
@@ -5,33 +7,44 @@ from typing import List, Optional, Union

@dataclasses.dataclass
class ServerArgs:
    # Model and tokenizer
    model_path: str
    tokenizer_path: Optional[str] = None
    host: str = "127.0.0.1"
    port: int = 30000
    additional_ports: Optional[Union[List[int], int]] = None
    load_format: str = "auto"
    tokenizer_mode: str = "auto"
    chat_template: Optional[str] = None
    trust_remote_code: bool = True
    context_length: Optional[int] = None

    # Port
    host: str = "127.0.0.1"
    port: int = 30000
    additional_ports: Optional[Union[List[int], int]] = None

    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_prefill_num_token: Optional[int] = None
    context_length: Optional[int] = None
    tp_size: int = 1
    schedule_heuristic: str = "lpm"
    schedule_conservativeness: float = 1.0
    attention_reduce_in_fp32: bool = False
    random_seed: int = 42

    # Other runtime options
    tp_size: int = 1
    stream_interval: int = 8
    random_seed: int = 42

    # Logging
    log_level: str = "info"
    disable_log_stats: bool = False
    log_stats_interval: int = 10
    log_level: str = "info"
    api_key: str = ""
    show_time_cost: bool = False

    # optional modes
    disable_radix_cache: bool = False
    # Other
    api_key: str = ""

    # Optimization/debug options
    enable_flashinfer: bool = False
    attention_reduce_in_fp32: bool = False
    disable_radix_cache: bool = False
    disable_regex_jump_forward: bool = False
    disable_disk_cache: bool = False
@@ -66,15 +79,16 @@ class ServerArgs:
            default=ServerArgs.tokenizer_path,
            help="The path of the tokenizer.",
        )
        parser.add_argument("--host", type=str, default=ServerArgs.host)
        parser.add_argument("--port", type=int, default=ServerArgs.port)
        # we want to be able to pass a list of ports
        parser.add_argument("--host", type=str, default=ServerArgs.host,
                            help="The host of the server.")
        parser.add_argument("--port", type=int, default=ServerArgs.port,
                            help="The port of the server.")
        parser.add_argument(
            "--additional-ports",
            type=int,
            nargs="*",
            default=[],
            help="Additional ports specified for launching server.",
            help="Additional ports specified for the server.",
        )
        parser.add_argument(
            "--load-format",
@@ -112,6 +126,12 @@ class ServerArgs:
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
@@ -124,18 +144,6 @@ class ServerArgs:
            default=ServerArgs.max_prefill_num_token,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--context-length",
            type=int,
            default=ServerArgs.context_length,
            help="The model's maximum context length. Use this to reduce the context length to save memory. Defaults to None (will use the value from the model's config.json instead).",
        )
        parser.add_argument(
            "--tp-size",
            type=int,
            default=ServerArgs.tp_size,
            help="Tensor parallelism degree.",
        )
        parser.add_argument(
            "--schedule-heuristic",
            type=str,
@@ -149,15 +157,10 @@ class ServerArgs:
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )
        parser.add_argument(
            "--random-seed",
            "--tp-size",
            type=int,
            default=ServerArgs.random_seed,
            help="Random seed.",
        )
        parser.add_argument(
            "--attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
            default=ServerArgs.tp_size,
            help="Tensor parallelism size.",
        )
        parser.add_argument(
            "--stream-interval",
@@ -165,11 +168,17 @@ class ServerArgs:
            default=ServerArgs.stream_interval,
            help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
        )
        parser.add_argument(
            "--random-seed",
            type=int,
            default=ServerArgs.random_seed,
            help="Random seed.",
        )
        parser.add_argument(
            "--log-level",
            type=str,
            default=ServerArgs.log_level,
            help="Log level",
            help="Logging level",
        )
        parser.add_argument(
            "--disable-log-stats",
@@ -182,29 +191,34 @@ class ServerArgs:
            default=ServerArgs.log_stats_interval,
            help="Log stats interval in second.",
        )
        parser.add_argument(
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API Key",
        )
        parser.add_argument(
            "--show-time-cost",
            action="store_true",
            help="Show time cost of custom marks",
        )

        # optional modes
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention",
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
            help="Set API key of the server",
        )

        # Optimization/debug options
        parser.add_argument(
            "--enable-flashinfer",
            action="store_true",
            help="Enable flashinfer inference kernels",
        )
        parser.add_argument(
            "--attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
        )
        parser.add_argument(
            "--disable-radix-cache",
            action="store_true",
            help="Disable RadixAttention",
        )
        parser.add_argument(
            "--disable-regex-jump-forward",
            action="store_true",
@@ -224,13 +238,13 @@ class ServerArgs:
    def url(self):
        return f"http://{self.host}:{self.port}"

    def get_optional_modes_logging(self):
    def print_mode_args(self):
        return (
            f"disable_radix_cache={self.disable_radix_cache}, "
            f"enable_flashinfer={self.enable_flashinfer}, "
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
            f"disable_radix_cache={self.disable_radix_cache}, "
            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
            f"disable_disk_cache={self.disable_disk_cache}, "
            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
        )


@@ -240,4 +254,4 @@ class PortArgs:
    router_port: int
    detokenizer_port: int
    nccl_port: int
    model_rpc_ports: List[int]
    model_rpc_ports: List[int]
@@ -10,9 +10,12 @@ from io import BytesIO
from typing import List, Optional

import numpy as np
import pydantic
import requests
import torch
from packaging import version as pkg_version
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware

show_time_cost = False
time_infos = {}
@@ -120,7 +123,7 @@ def check_port(port):
        return False


def handle_port_init(
def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    tp_size: int = 1,
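Note (not part of this commit): a usage sketch for the renamed helper (formerly handle_port_init), matching its call sites in launch_server and Runtime.__init__ above. It returns the main port plus a list of extra free ports used for the tokenizer, router, detokenizer, NCCL, and model RPC endpoints.

port, additional_ports = allocate_init_ports(
    port=30000,             # preferred main port, or None to auto-select
    additional_ports=None,  # or an explicit list of candidate ports
    tp_size=1,
)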
@@ -159,8 +162,6 @@ def get_exception_traceback():


def get_int_token_logit_bias(tokenizer, vocab_size):
    from transformers import LlamaTokenizer, LlamaTokenizerFast

    # a bug when model's vocab size > tokenizer.vocab_size
    vocab_size = tokenizer.vocab_size
    logit_bias = np.zeros(vocab_size, dtype=np.float32)
@@ -281,3 +282,32 @@ def assert_pkg_version(pkg: str, min_version: str):
        )
    except PackageNotFoundError:
        raise Exception(f"{pkg} with minimum required version {min_version} is not installed")


API_KEY_HEADER_NAME = "X-API-Key"


class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request, call_next):
        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
        if not api_key_header or api_key_header != self.api_key:
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        response = await call_next(request)
        return response


# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1


def jsonify_pydantic_model(obj: BaseModel):
    if IS_PYDANTIC_1:
        return obj.json(ensure_ascii=False)
    return obj.model_dump_json()
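Note (not part of this commit): once launch_server installs this middleware via app.add_middleware(APIKeyValidatorMiddleware, api_key=...), every HTTP request must carry the X-API-Key header. A client-side sketch against the default host and port:

import requests

headers = {"X-API-Key": "my-secret-key"}  # must match the server's --api-key value
r = requests.get("http://127.0.0.1:30000/health", headers=headers)
print(r.status_code)  # 403 with a missing or wrong key, 200 otherwise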