"""SRT: SGLang Runtime"""
|
|
|
|
import asyncio
|
|
import dataclasses
|
|
import json
|
|
import multiprocessing as mp
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
from typing import List, Optional, Union
|
|
|
|
# Work around a CPython shutdown issue: concurrent.futures registers an atexit
# hook through threading._register_atexit that can hang interpreter shutdown
# while worker threads are still alive. Replacing it with a no-op lets the
# process exit cleanly. NOTE(review): this disables the hook globally —
# presumably nothing in this server relies on it; confirm before upgrading Python.
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
|
|
|
import aiohttp
|
|
import psutil
|
|
import pydantic
|
|
import requests
|
|
import uvicorn
|
|
import uvloop
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from pydantic import BaseModel
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import JSONResponse
|
|
|
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
|
from sglang.srt.constrained import disable_cache
|
|
from sglang.srt.conversation import (
|
|
Conversation,
|
|
SeparatorStyle,
|
|
chat_template_exists,
|
|
generate_chat_conv,
|
|
register_conv_template,
|
|
)
|
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
|
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
|
|
from sglang.srt.managers.io_struct import DetokenizeReqInput, GenerateReqInput
|
|
from sglang.srt.managers.openai_protocol import (
|
|
ChatCompletionRequest,
|
|
ChatCompletionResponse,
|
|
ChatCompletionResponseChoice,
|
|
ChatCompletionResponseStreamChoice,
|
|
ChatCompletionStreamResponse,
|
|
ChatMessage,
|
|
CompletionRequest,
|
|
CompletionResponse,
|
|
CompletionResponseChoice,
|
|
CompletionResponseStreamChoice,
|
|
CompletionStreamResponse,
|
|
DeltaMessage,
|
|
LogProbs,
|
|
UsageInfo,
|
|
)
|
|
from sglang.srt.managers.router.manager import start_router_process
|
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
from sglang.srt.utils import enable_show_time_cost, handle_port_init
|
|
|
|
# Use uvloop as the asyncio event loop implementation for better throughput.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

# HTTP header checked by APIKeyValidatorMiddleware for the client API key.
API_KEY_HEADER_NAME = "X-API-Key"
|
|
|
|
|
|
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """Reject every request that does not carry the expected API key header."""

    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        # Compare the key supplied by the client against the configured one.
        provided_key = request.headers.get(API_KEY_HEADER_NAME)
        if not provided_key or provided_key != self.api_key:
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        return await call_next(request)
|
|
|
|
|
|
app = FastAPI()

# Populated by launch_server() before uvicorn starts; read by the handlers below.
tokenizer_manager = None
chat_template_name = None


# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1
|
|
|
|
|
|
def jsonify_pydantic_model(obj: BaseModel):
    """Serialize a pydantic model to a JSON string across pydantic 1.x and 2.x."""
    if not IS_PYDANTIC_1:
        return obj.model_dump_json()
    # pydantic 1.x spells the JSON serializer `.json()`.
    return obj.json(ensure_ascii=False)
|
|
|
|
|
|
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answer HTTP 200 with an empty body."""
    ok = Response(status_code=200)
    return ok
|
|
|
|
|
|
@app.get("/get_model_info")
async def get_model_info():
    """Report the path of the model currently being served."""
    return {"model_path": tokenizer_manager.model_path}
|
|
|
|
|
|
@app.get("/get_server_args")
async def get_server_args():
    """Return the launch-time ServerArgs as a plain, JSON-serializable dict."""
    return dataclasses.asdict(tokenizer_manager.server_args)
|
|
|
|
|
|
@app.get("/flush_cache")
async def flush_cache():
    """Ask the backend to drop its cache.

    The backend skips the flush while requests are running or queued;
    the response text points the caller at the backend logs for the outcome.
    """
    await tokenizer_manager.flush_cache()
    message = (
        "Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n"
    )
    return Response(content=message, status_code=200)
|
|
|
|
|
|
async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
    """Expand (logprob, token_id) pairs into (logprob, token_id, token_text).

    When decode_to_text is falsy, token_text is None and no detokenizer
    round-trip is made; otherwise the token ids are decoded in one batch.
    """
    if not decode_to_text:
        return [(lp, tid, None) for lp, tid in token_logprobs]

    token_ids = [tid for _, tid in token_logprobs]
    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
    return [
        (lp, tid, text)
        for (lp, tid), text in zip(token_logprobs, token_texts)
    ]
|
|
|
|
|
|
async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text):
    """Detokenize every non-None per-position entry of top_logprobs, in place."""
    for idx, entry in enumerate(top_logprobs):
        if entry is not None:
            top_logprobs[idx] = await detokenize_logprob_tokens(entry, decode_to_text)
    return top_logprobs
|
|
|
|
|
|
async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
    """Handle the token logprobs results, convert token ids to text if needed.

    Args:
        obj (GenerateReqInput): The request object.
        ret (Union[Dict, List[Dict]]): The response object.
    """

    async def convert_style(r, return_text):
        # Rewrite all four logprob fields of one response's meta_info in place.
        meta = r["meta_info"]
        meta["prefill_token_logprobs"] = await detokenize_logprob_tokens(
            meta["prefill_token_logprobs"], return_text
        )
        meta["decode_token_logprobs"] = await detokenize_logprob_tokens(
            meta["decode_token_logprobs"], return_text
        )
        meta["prefill_top_logprobs"] = await detokenize_top_logprobs_tokens(
            meta["prefill_top_logprobs"], return_text
        )
        meta["decode_top_logprobs"] = await detokenize_top_logprobs_tokens(
            meta["decode_top_logprobs"], return_text
        )

    if isinstance(obj.text, str):
        # Single request: ret is one dict and return_logprob is one flag.
        if obj.return_logprob:
            await convert_style(ret, obj.return_text_in_logprobs)
    else:
        # Batched request: ret and obj.return_logprob are parallel lists
        # (one HTTP request can carry multiple generation requests).
        for i, r in enumerate(ret):
            if obj.return_logprob[i]:
                await convert_style(r, obj.return_text_in_logprobs)
|
|
|
|
|
|
async def stream_generator(obj: GenerateReqInput):
    """Yield generation chunks for *obj*, post-processing logprobs in each one."""
    async for out in tokenizer_manager.generate_request(obj):
        await handle_token_logprobs_results(obj, out)
        yield out
|
|
|
|
|
|
async def make_openai_style_logprobs(
    prefill_token_logprobs=None,
    decode_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
):
    """Assemble an OpenAI-style LogProbs object from SRT logprob tuples.

    Each *_token_logprobs argument is a list of (logprob, token_id, token_text)
    triples; each *_top_logprobs argument is a list whose entries are either
    None or a list of such triples. Any argument left as None is skipped.
    """
    out = LogProbs()

    def extend_tokens(triples):
        # Flatten (logprob, _, token_text) triples into the parallel arrays.
        for logprob, _, token_text in triples:
            out.tokens.append(token_text)
            out.token_logprobs.append(logprob)
            # Text offsets are not supported yet.
            out.text_offset.append(-1)

    def extend_top(entries):
        # Each position becomes a {token_text: logprob} dict, or None.
        for candidates in entries:
            if candidates is None:
                out.top_logprobs.append(None)
            else:
                out.top_logprobs.append({c[2]: c[0] for c in candidates})

    if prefill_token_logprobs is not None:
        extend_tokens(prefill_token_logprobs)
    if decode_token_logprobs is not None:
        extend_tokens(decode_token_logprobs)
    if prefill_top_logprobs is not None:
        extend_top(prefill_top_logprobs)
    if decode_top_logprobs is not None:
        extend_top(decode_top_logprobs)

    return out
|
|
|
|
|
|
@app.post("/generate")
async def generate_request(obj: GenerateReqInput):
    """Run one generation request; stream SSE chunks when obj.stream is set."""
    obj.post_init()

    if not obj.stream:
        # Non-streaming: the async generator yields exactly one final result.
        ret = await tokenizer_manager.generate_request(obj).__anext__()
        await handle_token_logprobs_results(obj, ret)
        return ret

    async def stream_results():
        async for out in stream_generator(obj):
            yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(stream_results(), media_type="text/event-stream")
|
|
|
|
|
|
@app.post("/v1/completions")
async def v1_completions(raw_request: Request):
    """OpenAI-compatible /v1/completions endpoint.

    Translates a CompletionRequest into a GenerateReqInput, runs it through
    the local /generate handler, and adapts the result back to the OpenAI
    completion schema (streaming SSE or a single JSON response).
    """
    request_json = await raw_request.json()
    request = CompletionRequest(**request_json)

    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
    assert request.n == 1

    adapted_request = GenerateReqInput(
        text=request.prompt,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": request.stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "regex": request.regex,
        },
        return_logprob=request.logprobs is not None and request.logprobs > 0,
        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
        # Always request token text so OpenAI-style logprobs can be built.
        return_text_in_logprobs=True,
        stream=request.stream,
    )
    adapted_request.post_init()

    if adapted_request.stream:

        # Fixed misspelled local name `gnerate_stream_resp`.
        async def generate_stream_resp():
            stream_buffer = ""
            n_prev_token = 0  # count of decode logprobs already emitted
            async for content in stream_generator(adapted_request):
                text = content["text"]
                prompt_tokens = content["meta_info"]["prompt_tokens"]
                completion_tokens = content["meta_info"]["completion_tokens"]

                if not stream_buffer:  # The first chunk
                    if request.echo:
                        # Prepend prompt in response text.
                        text = request.prompt + text

                if request.logprobs:
                    # Prefill logprobs are sent only once, on the first chunk,
                    # and only when the prompt is echoed back.
                    if not stream_buffer and request.echo:
                        prefill_token_logprobs = content["meta_info"][
                            "prefill_token_logprobs"
                        ]
                        prefill_top_logprobs = content["meta_info"][
                            "prefill_top_logprobs"
                        ]
                    else:
                        prefill_token_logprobs = None
                        prefill_top_logprobs = None

                    logprobs = await make_openai_style_logprobs(
                        prefill_token_logprobs=prefill_token_logprobs,
                        prefill_top_logprobs=prefill_top_logprobs,
                        decode_token_logprobs=content["meta_info"][
                            "decode_token_logprobs"
                        ][n_prev_token:],
                        decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
                            n_prev_token:
                        ],
                    )

                    n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
                else:
                    logprobs = None

                # Emit only the text produced since the previous chunk.
                delta = text[len(stream_buffer) :]
                stream_buffer = content["text"]
                choice_data = CompletionResponseStreamChoice(
                    index=0,
                    text=delta,
                    logprobs=logprobs,
                    finish_reason=None,
                )
                chunk = CompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    object="text_completion",
                    choices=[choice_data],
                    model=request.model,
                    usage=UsageInfo(
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        total_tokens=prompt_tokens + completion_tokens,
                    ),
                )
                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await generate_request(adapted_request)
    ret = ret[0] if isinstance(ret, list) else ret

    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    text = ret["text"]
    if request.echo:
        text = request.prompt + text

    if request.logprobs:
        # Prefill logprobs are only included when the prompt is echoed back.
        if request.echo:
            prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"]
            prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"]
        else:
            prefill_token_logprobs = None
            prefill_top_logprobs = None

        logprobs = await make_openai_style_logprobs(
            prefill_token_logprobs=prefill_token_logprobs,
            prefill_top_logprobs=prefill_top_logprobs,
            decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"],
            decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"],
        )
    else:
        logprobs = None

    choice_data = CompletionResponseChoice(
        index=0,
        text=text,
        logprobs=logprobs,
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )

    response = CompletionResponse(
        id=ret["meta_info"]["id"],
        model=request.model,
        choices=[choice_data],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return response
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def v1_chat_completions(raw_request: Request):
    """OpenAI-compatible /v1/chat/completions endpoint.

    Renders the chat messages into a prompt (via an sglang conv template or
    the tokenizer's HuggingFace chat template), runs it through the local
    /generate handler, and adapts the result to the OpenAI chat schema
    (streaming SSE or a single JSON response).
    """
    request_json = await raw_request.json()
    request = ChatCompletionRequest(**request_json)

    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
    assert request.n == 1

    # Prep the data needed for the underlying GenerateReqInput:
    # - prompt: The full prompt string.
    # - stop: Custom stop tokens.
    # - image_data: None or a list of image strings (URLs or base64 strings).
    #   None skips any image processing in GenerateReqInput.
    if not isinstance(request.messages, str):
        # Apply chat template and its stop strings.
        if chat_template_name is None:
            # This flow doesn't support the full OpenAI spec. Verify messages
            # has the right type before proceeding:
            for m in request.messages:
                if not isinstance(m.content, str):
                    raise HTTPException(
                        status_code=503,
                        detail="Structured content requests not supported with "
                        "HuggingFace Chat Templates. "
                        "Make sure the server specifies a sglang chat template.",
                    )
            prompt = tokenizer_manager.tokenizer.apply_chat_template(
                request.messages, tokenize=False, add_generation_prompt=True
            )
            stop = request.stop
            image_data = None
        else:
            conv = generate_chat_conv(request, chat_template_name)
            prompt = conv.get_prompt()
            image_data = conv.image_data
            stop = conv.stop_str or []
            # Merge the caller's stop strings with the template's.
            if request.stop:
                if isinstance(request.stop, str):
                    stop.append(request.stop)
                else:
                    stop.extend(request.stop)
    else:
        # Use the raw prompt and stop strings if the messages is already a string.
        prompt = request.messages
        stop = request.stop
        image_data = None

    adapted_request = GenerateReqInput(
        text=prompt,
        image_data=image_data,
        sampling_params={
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "stop": stop,
            "top_p": request.top_p,
            "presence_penalty": request.presence_penalty,
            "frequency_penalty": request.frequency_penalty,
            "regex": request.regex,
        },
        stream=request.stream,
    )
    adapted_request.post_init()

    if adapted_request.stream:

        # Fixed misspelled local name `gnerate_stream_resp`.
        async def generate_stream_resp():
            is_first = True

            stream_buffer = ""
            async for content in stream_generator(adapted_request):
                if is_first:
                    # First chunk with role
                    is_first = False
                    choice_data = ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(role="assistant"),
                        finish_reason=None,
                    )
                    chunk = ChatCompletionStreamResponse(
                        id=content["meta_info"]["id"],
                        choices=[choice_data],
                        model=request.model,
                    )
                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"

                # Emit only the text produced since the previous chunk.
                text = content["text"]
                delta = text[len(stream_buffer) :]
                stream_buffer = text
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
                )
                chunk = ChatCompletionStreamResponse(
                    id=content["meta_info"]["id"],
                    choices=[choice_data],
                    model=request.model,
                )
                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")

    # Non-streaming response.
    ret = await generate_request(adapted_request)
    prompt_tokens = ret["meta_info"]["prompt_tokens"]
    completion_tokens = ret["meta_info"]["completion_tokens"]
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=ret["text"]),
        finish_reason=None,  # TODO(comaniac): Add finish reason.
    )
    response = ChatCompletionResponse(
        id=ret["meta_info"]["id"],
        model=request.model,
        choices=[choice_data],
        usage=UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
    return response
|
|
|
|
|
|
def launch_server(server_args, pipe_finish_writer):
    """Launch the HTTP server plus the router and detokenizer subprocesses.

    Blocks in uvicorn until the server exits. A background thread polls the
    server and sends "init ok" (or an error string) through
    pipe_finish_writer once a warmup request succeeds, so a parent process
    can wait for readiness.
    """
    global tokenizer_manager
    global chat_template_name

    # start show time thread
    if server_args.show_time_cost:
        enable_show_time_cost()

    # disable disk cache if needed
    if server_args.disable_disk_cache:
        disable_cache()

    # Handle ports
    server_args.port, server_args.additional_ports = handle_port_init(
        server_args.port, server_args.additional_ports, server_args.tp_size
    )

    port_args = PortArgs(
        tokenizer_port=server_args.additional_ports[0],
        router_port=server_args.additional_ports[1],
        detokenizer_port=server_args.additional_ports[2],
        nccl_port=server_args.additional_ports[3],
        model_rpc_ports=server_args.additional_ports[4:],
    )

    # Load chat template if needed
    if server_args.chat_template is not None:
        print(f"Use chat template: {server_args.chat_template}")
        if not chat_template_exists(server_args.chat_template):
            if not os.path.exists(server_args.chat_template):
                raise RuntimeError(
                    f"Chat template {server_args.chat_template} is not a built-in template name "
                    "or a valid chat template file path."
                )
            # Register a custom conversation template loaded from a JSON file.
            with open(server_args.chat_template, "r") as filep:
                template = json.load(filep)
                try:
                    sep_style = SeparatorStyle[template["sep_style"]]
                except KeyError:
                    raise ValueError(
                        f"Unknown separator style: {template['sep_style']}"
                    ) from None
                register_conv_template(
                    Conversation(
                        name=template["name"],
                        system_template=template["system"] + "\n{system_message}",
                        system_message=template.get("system_message", ""),
                        roles=(template["user"], template["assistant"]),
                        sep_style=sep_style,
                        sep=template.get("sep", "\n"),
                        stop_str=template["stop_str"],
                    ),
                    override=True,
                )
            chat_template_name = template["name"]
        else:
            chat_template_name = server_args.chat_template

    # Launch processes
    tokenizer_manager = TokenizerManager(server_args, port_args)
    pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    proc_router = mp.Process(
        target=start_router_process,
        args=(
            server_args,
            port_args,
            pipe_router_writer,
        ),
    )
    proc_router.start()
    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
            server_args,
            port_args,
            pipe_detoken_writer,
        ),
    )
    proc_detoken.start()

    # Wait for the model to finish loading
    router_init_state = pipe_router_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

    if router_init_state != "init ok" or detoken_init_state != "init ok":
        proc_router.kill()
        proc_detoken.kill()
        print("router init state:", router_init_state)
        print("detoken init state:", detoken_init_state)
        sys.exit(1)

    assert proc_router.is_alive() and proc_detoken.is_alive()

    if server_args.api_key and server_args.api_key != "":
        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)

    def _launch_server():
        # Run uvicorn in the current thread; blocks until shutdown.
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )

    def _wait_and_warmup():
        headers = {}
        url = server_args.url()
        if server_args.api_key and server_args.api_key != "":
            headers[API_KEY_HEADER_NAME] = server_args.api_key

        # Poll /get_model_info for up to ~60s until the server is reachable.
        # BUG FIX: the previous for/else referenced `e` after the except
        # block, but Python 3 unbinds the exception name when the block
        # exits, so the all-retries-failed path raised NameError instead of
        # reporting the error. Keep the last error in a separate variable.
        last_error = None
        for _ in range(120):
            time.sleep(0.5)
            try:
                requests.get(url + "/get_model_info", timeout=5, headers=headers)
                break
            except requests.exceptions.RequestException as e:
                last_error = e
        else:
            # Every attempt failed; report the last connection error.
            if pipe_finish_writer is not None:
                pipe_finish_writer.send(str(last_error))
            else:
                print(last_error, flush=True)
            return

        # Warmup request so the first real request does not pay the
        # cold-start cost.
        try:
            requests.post(
                url + "/generate",
                json={
                    "text": "Say this is a warmup request.",
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 16,
                    },
                },
                headers=headers,
                timeout=60,
            )
        except requests.exceptions.RequestException as e:
            if pipe_finish_writer is not None:
                pipe_finish_writer.send(str(e))
            else:
                print(e, flush=True)
            return

        if pipe_finish_writer is not None:
            pipe_finish_writer.send("init ok")

    t = threading.Thread(target=_wait_and_warmup)
    t.start()
    try:
        _launch_server()
    finally:
        t.join()
|
|
|
|
|
|
class Runtime:
    """Convenience wrapper that launches the SRT server in a subprocess and
    exposes it as an sglang RuntimeEndpoint for in-process use."""

    def __init__(
        self,
        model_path: str,
        tokenizer_path: Optional[str] = None,
        load_format: str = "auto",
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = True,
        mem_fraction_static: float = ServerArgs.mem_fraction_static,
        max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
        context_length: int = ServerArgs.context_length,
        tp_size: int = 1,
        schedule_heuristic: str = "lpm",
        attention_reduce_in_fp32: bool = False,
        random_seed: int = 42,
        log_level: str = "error",
        disable_radix_cache: bool = False,
        enable_flashinfer: bool = False,
        disable_regex_jump_forward: bool = False,
        disable_disk_cache: bool = False,
        api_key: str = "",
        port: Optional[int] = None,
        additional_ports: Optional[Union[List[int], int]] = None,
    ):
        """Start the server subprocess and block until it reports readiness.

        Raises:
            RuntimeError: if the subprocess does not send "init ok".
        """
        host = "127.0.0.1"
        port, additional_ports = handle_port_init(port, additional_ports, tp_size)
        self.server_args = ServerArgs(
            model_path=model_path,
            tokenizer_path=tokenizer_path,
            host=host,
            port=port,
            additional_ports=additional_ports,
            load_format=load_format,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            mem_fraction_static=mem_fraction_static,
            max_prefill_num_token=max_prefill_num_token,
            context_length=context_length,
            tp_size=tp_size,
            schedule_heuristic=schedule_heuristic,
            attention_reduce_in_fp32=attention_reduce_in_fp32,
            random_seed=random_seed,
            log_level=log_level,
            disable_radix_cache=disable_radix_cache,
            enable_flashinfer=enable_flashinfer,
            disable_regex_jump_forward=disable_regex_jump_forward,
            disable_disk_cache=disable_disk_cache,
            api_key=api_key,
        )

        # Endpoint URLs derived from the chosen host/port.
        self.url = self.server_args.url()
        self.generate_url = (
            f"http://{self.server_args.host}:{self.server_args.port}/generate"
        )

        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
        proc.start()
        # Close our copy of the write end so recv() raises EOFError if the
        # child exits before sending its init state.
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError("Launch failed. Please see the error messages above.")

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        """Kill the server process tree (children first). Safe to call twice."""
        if self.pid is not None:
            try:
                parent = psutil.Process(self.pid)
            except psutil.NoSuchProcess:
                # Process is already gone; nothing to clean up.
                return
            children = parent.children(recursive=True)
            for child in children:
                child.kill()
            psutil.wait_procs(children, timeout=5)
            parent.kill()
            parent.wait(timeout=5)
            self.pid = None

    def get_tokenizer(self):
        """Load and return the tokenizer matching the server's configuration."""
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def add_request(
        self,
        prompt: str,
        sampling_params,
    ) -> None:
        # NOTE(review): despite the `-> None` annotation this is an async
        # generator that yields incremental text chunks (str); consider
        # correcting the annotation to AsyncIterator[str].
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "stream": True,
        }

        # Offset of the text already yielded; only new text is emitted.
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        cur = data["text"][pos:]
                        if cur:
                            yield cur
                        pos += len(cur)

    def __del__(self):
        # Ensure the server subprocess is torn down with this object.
        self.shutdown()
|