Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
@@ -25,11 +25,14 @@ import os
 import threading
 import time
 from http import HTTPStatus
-from typing import AsyncIterator, Dict, Optional
+from typing import AsyncIterator, Callable, Dict, Optional

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

+from contextlib import asynccontextmanager
+
+import numpy as np
 import orjson
 import requests
 import uvicorn
@@ -49,8 +52,10 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
     ParseFunctionCallReq,
+    ProfileReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
+    SetInternalStateReq,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     VertexGenerateReqInput,
@@ -78,22 +83,13 @@ from sglang.srt.utils import (
     kill_process_tree,
     set_uvicorn_logging_configs,
 )
+from sglang.srt.warmup import execute_warmups
 from sglang.utils import get_exception_traceback
 from sglang.version import __version__

 logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

-# Fast API
-app = FastAPI()
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-

 # Store global states
 @dataclasses.dataclass
@@ -110,6 +106,34 @@ def set_global_state(global_state: _GlobalState):
     _global_state = global_state


+@asynccontextmanager
+async def lifespan(fast_api_app: FastAPI):
+    server_args: ServerArgs = fast_api_app.server_args
+    if server_args.warmups is not None:
+        await execute_warmups(
+            server_args.warmups.split(","), _global_state.tokenizer_manager
+        )
+        logger.info("Warmup ended")
+
+    warmup_thread = getattr(fast_api_app, "warmup_thread", None)
+    if warmup_thread is not None:
+        warmup_thread.start()
+    yield
+
+
+# Fast API
+app = FastAPI(lifespan=lifespan)
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
+
+
 ##### Native API endpoints #####

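The lifespan handler above is what lets the warmup thread start only once the ASGI app is actually serving. As a minimal, generic sketch of the FastAPI lifespan pattern this relies on (not sglang's code; the startup task is made up), everything before `yield` runs once at startup and everything after it runs at shutdown:

import time
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs once, after uvicorn has created the event loop.
    app.state.started_at = time.time()
    yield
    # Shutdown: runs once, when the server is stopping.
    print(f"Server ran for {time.time() - app.state.started_at:.1f}s")


app = FastAPI(lifespan=lifespan)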
@@ -123,24 +147,48 @@ async def health() -> Response:
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""

-    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+    sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
+    rid = f"HEALTH_CHECK_{time.time()}"

-    if _global_state.tokenizer_manager.is_generation:
+    if _global_state.tokenizer_manager.is_image_gen:
+        raise NotImplementedError()
+    elif _global_state.tokenizer_manager.is_generation:
         gri = GenerateReqInput(
-            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+            rid=rid,
+            input_ids=[0],
+            sampling_params=sampling_params,
+            log_metrics=False,
         )
     else:
         gri = EmbeddingReqInput(
-            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+            rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
         )

-    try:
+    async def gen():
         async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
             break
-        return Response(status_code=200)
-    except Exception as e:
-        logger.exception(e)
-        return Response(status_code=503)
+
+    tic = time.time()
+    task = asyncio.create_task(gen())
+    while time.time() < tic + HEALTH_CHECK_TIMEOUT:
+        await asyncio.sleep(1)
+        if _global_state.tokenizer_manager.last_receive_tstamp > tic:
+            task.cancel()
+            _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+            return Response(status_code=200)
+
+    task.cancel()
+    tic_time = time.strftime("%H:%M:%S", time.localtime(tic))
+    last_receive_time = time.strftime(
+        "%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp)
+    )
+    logger.error(
+        f"Health check failed. Server couldn't get a response from detokenizer for last "
+        f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. "
+        f"last_heartbeat time: {last_receive_time}"
+    )
+    _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+    return Response(status_code=503)


 @app.get("/get_model_info")
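A hedged client-side sketch of exercising the reworked health check: the host and port are assumptions, while the endpoint path, the 200/503 behavior, and the SGLANG_HEALTH_CHECK_TIMEOUT default of 20 seconds come from the hunk above.

import requests

BASE_URL = "http://127.0.0.1:30000"  # assumed host/port

# Returns 200 once the detokenizer reports activity within the
# SGLANG_HEALTH_CHECK_TIMEOUT window (20 s by default), otherwise 503.
resp = requests.get(f"{BASE_URL}/health_generate", timeout=30)
print("healthy" if resp.status_code == 200 else "unhealthy")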
@@ -156,13 +204,21 @@ async def get_model_info():

 @app.get("/get_server_info")
 async def get_server_info():
+    internal_states = await _global_state.tokenizer_manager.get_internal_state()
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
+        **internal_states,
         "version": __version__,
     }


+@app.api_route("/set_internal_state", methods=["POST", "PUT"])
+async def set_internal_state(obj: SetInternalStateReq, request: Request):
+    res = await _global_state.tokenizer_manager.set_internal_state(obj)
+    return res
+
+
 # fastapi implicitly converts json in the request to obj (dataclass)
 @app.api_route("/generate", methods=["POST", "PUT"])
 async def generate_request(obj: GenerateReqInput, request: Request):
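For context, a hedged sketch of querying the enriched /get_server_info endpoint (host and port assumed); per the hunk above, the response now merges the server arguments, the scheduler info, the tokenizer manager's internal state, and the version string.

import requests

BASE_URL = "http://127.0.0.1:30000"  # assumed host/port

info = requests.get(f"{BASE_URL}/get_server_info").json()
print(info["version"])  # populated from sglang.version.__version__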
@@ -179,6 +235,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
                     ) + b"\n\n"
             except ValueError as e:
                 out = {"error": {"message": str(e)}}
+                logger.error(f"Error: {e}")
                 yield b"data: " + orjson.dumps(
                     out, option=orjson.OPT_NON_STR_KEYS
                 ) + b"\n\n"
@@ -236,9 +293,14 @@ async def flush_cache():


 @app.api_route("/start_profile", methods=["GET", "POST"])
-async def start_profile_async():
+async def start_profile_async(obj: Optional[ProfileReqInput] = None):
     """Start profiling."""
-    _global_state.tokenizer_manager.start_profile()
+    if obj is None:
+        obj = ProfileReqInput()
+
+    await _global_state.tokenizer_manager.start_profile(
+        obj.output_dir, obj.num_steps, obj.activities
+    )
     return Response(
         content="Start profiling.\n",
         status_code=200,
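A hedged sketch of driving the new ProfileReqInput-based endpoint. The field names output_dir, num_steps, and activities come from the call above; the concrete values and the host/port are illustrative assumptions, and an empty body still works because the handler falls back to ProfileReqInput().

import requests

BASE_URL = "http://127.0.0.1:30000"  # assumed host/port

payload = {
    "output_dir": "/tmp/sglang_profile",  # illustrative path
    "num_steps": 10,                      # illustrative step count
    "activities": ["CPU", "GPU"],         # illustrative activity list
}
resp = requests.post(f"{BASE_URL}/start_profile", json=payload)
print(resp.status_code, resp.text)  # expected: 200, "Start profiling.\n"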
@@ -257,11 +319,15 @@ async def stop_profile_async():

 @app.post("/update_weights_from_disk")
 async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
-    """Update the weights from disk in-place without re-launching the server."""
-    success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
-        obj, request
+    """Update the weights from disk inplace without re-launching the server."""
+    success, message, num_paused_requests = (
+        await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
     )
-    content = {"success": success, "message": message}
+    content = {
+        "success": success,
+        "message": message,
+        "num_paused_requests": num_paused_requests,
+    }
     if success:
         return ORJSONResponse(
             content,
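A hedged sketch of the new response shape for /update_weights_from_disk: the host/port and the request body are assumptions (the model_path field is inferred from UpdateWeightFromDiskReqInput), while success, message, and num_paused_requests are the keys built in the handler above.

import requests

BASE_URL = "http://127.0.0.1:30000"            # assumed host/port
NEW_CHECKPOINT = "/models/my-finetuned-model"  # hypothetical checkpoint path

resp = requests.post(
    f"{BASE_URL}/update_weights_from_disk",
    json={"model_path": NEW_CHECKPOINT},  # field name assumed
)
body = resp.json()
print(body["success"], body["message"], body["num_paused_requests"])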
@@ -323,7 +389,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
 async def release_memory_occupation(
     obj: ReleaseMemoryOccupationReqInput, request: Request
 ):
-    """Release GPU occupation temporarily"""
+    """Release GPU memory occupation temporarily."""
     try:
         await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
     except Exception as e:
@@ -334,7 +400,7 @@ async def release_memory_occupation(
 async def resume_memory_occupation(
     obj: ResumeMemoryOccupationReqInput, request: Request
 ):
-    """Resume GPU occupation"""
+    """Resume GPU memory occupation."""
     try:
         await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
     except Exception as e:
@@ -357,7 +423,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):

 @app.api_route("/close_session", methods=["GET", "POST"])
 async def close_session(obj: CloseSessionReqInput, request: Request):
-    """Close the session"""
+    """Close the session."""
     try:
         await _global_state.tokenizer_manager.close_session(obj, request)
         return Response(status_code=200)
@@ -367,7 +433,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):

 @app.api_route("/configure_logging", methods=["GET", "POST"])
 async def configure_logging(obj: ConfigureLoggingReq, request: Request):
-    """Close the session"""
+    """Configure the request logging options."""
     _global_state.tokenizer_manager.configure_logging(obj)
     return Response(status_code=200)

@@ -511,6 +577,7 @@ def _create_error_response(e):
 def launch_server(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
+    launch_callback: Optional[Callable[[], None]] = None,
 ):
     """
     Launch SRT (SGLang Runtime) Server.
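The new launch_callback parameter gives embedding code a hook that fires from the warmup thread once the server is ready. A hedged sketch, assuming the module and import paths and using an illustrative model:

from sglang.srt.entrypoints.http_server import launch_server  # module path assumed
from sglang.srt.server_args import ServerArgs                 # import path assumed


def on_ready() -> None:
    # Invoked by _wait_and_warmup after the warmup request succeeds.
    print("SGLang HTTP server is warmed up and accepting traffic")


server_args = ServerArgs(model_path="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model
launch_server(server_args, launch_callback=on_ready)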
@@ -544,21 +611,23 @@ def launch_server(
     add_prometheus_middleware(app)
     enable_func_timer()

-    # Send a warmup request
-    t = threading.Thread(
+    # Send a warmup request - we will create the thread and launch it
+    # in the lifespan after all other warmups have fired.
+    warmup_thread = threading.Thread(
         target=_wait_and_warmup,
         args=(
             server_args,
             pipe_finish_writer,
             _global_state.tokenizer_manager.image_token_id,
+            launch_callback,
         ),
     )
-    t.start()
+    app.warmup_thread = warmup_thread

     try:
         # Update logging configs
         set_uvicorn_logging_configs()

+        app.server_args = server_args
         # Listen for HTTP requests
         uvicorn.run(
             app,
@@ -569,10 +638,15 @@ def launch_server(
             loop="uvloop",
         )
     finally:
-        t.join()
+        warmup_thread.join()


-def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
+def _wait_and_warmup(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[multiprocessing.connection.Connection],
+    image_token_text: str,
+    launch_callback: Optional[Callable[[], None]] = None,
+):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -614,8 +688,16 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
     else:
         json_data["text"] = "The capital city of France is"

+    # Debug dumping
+    if server_args.debug_tensor_dump_input_file:
+        json_data.pop("text", None)
+        json_data["input_ids"] = np.load(
+            server_args.debug_tensor_dump_input_file
+        ).tolist()
+        json_data["sampling_params"]["max_new_tokens"] = 0
+
     try:
-        for _ in range(server_args.dp_size):
+        for i in range(server_args.dp_size):
             res = requests.post(
                 url + request_name,
                 json=json_data,
@@ -640,3 +722,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):

     if server_args.delete_ckpt_after_loading:
         delete_directory(server_args.model_path)
+
+    if server_args.debug_tensor_dump_input_file:
+        kill_process_tree(os.getpid())
+
+    if launch_callback is not None:
+        launch_callback()