Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
Author: Lianmin Zheng
Date: 2025-03-03 00:12:04 -08:00
Parent: 0194948fd9
Commit: ac2387279e
86 changed files with 4116 additions and 2015 deletions


@@ -25,11 +25,14 @@ import os
import threading
import time
from http import HTTPStatus
-from typing import AsyncIterator, Dict, Optional
+from typing import AsyncIterator, Callable, Dict, Optional
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
+from contextlib import asynccontextmanager
+import numpy as np
import orjson
import requests
import uvicorn
@@ -49,8 +52,10 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
+ProfileReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
+SetInternalStateReq,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
VertexGenerateReqInput,
@@ -78,22 +83,13 @@ from sglang.srt.utils import (
kill_process_tree,
set_uvicorn_logging_configs,
)
+from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback
from sglang.version import __version__
logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-# Fast API
-app = FastAPI()
-app.add_middleware(
-CORSMiddleware,
-allow_origins=["*"],
-allow_credentials=True,
-allow_methods=["*"],
-allow_headers=["*"],
-)
# Store global states
@dataclasses.dataclass
@@ -110,6 +106,34 @@ def set_global_state(global_state: _GlobalState):
_global_state = global_state
+@asynccontextmanager
+async def lifespan(fast_api_app: FastAPI):
+server_args: ServerArgs = fast_api_app.server_args
+if server_args.warmups is not None:
+await execute_warmups(
+server_args.warmups.split(","), _global_state.tokenizer_manager
+)
+logger.info("Warmup ended")
+warmup_thread = getattr(fast_api_app, "warmup_thread", None)
+if warmup_thread is not None:
+warmup_thread.start()
+yield
+# Fast API
+app = FastAPI(lifespan=lifespan)
+app.add_middleware(
+CORSMiddleware,
+allow_origins=["*"],
+allow_credentials=True,
+allow_methods=["*"],
+allow_headers=["*"],
+)
+HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
##### Native API endpoints #####
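For readers unfamiliar with the hook adopted above: FastAPI's lifespan context manager replaces the older startup/shutdown events and runs inside the live event loop before the first request. A minimal generic sketch of the pattern (not the sglang code itself):

    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Startup work runs here, inside the running event loop,
        # before the server begins accepting requests.
        print("running warmups...")
        yield
        # Anything after the yield runs at shutdown.

    app = FastAPI(lifespan=lifespan)

Deferring warmup-thread start into the lifespan (as the hunk does via app.warmup_thread) guarantees the warmups fire only once the loop is up.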
@@ -123,24 +147,48 @@ async def health() -> Response:
async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token."""
-sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
+rid = f"HEALTH_CHECK_{time.time()}"
-if _global_state.tokenizer_manager.is_generation:
+if _global_state.tokenizer_manager.is_image_gen:
+raise NotImplementedError()
+elif _global_state.tokenizer_manager.is_generation:
gri = GenerateReqInput(
-input_ids=[0], sampling_params=sampling_params, log_metrics=False
+rid=rid,
+input_ids=[0],
+sampling_params=sampling_params,
+log_metrics=False,
)
else:
gri = EmbeddingReqInput(
-input_ids=[0], sampling_params=sampling_params, log_metrics=False
+rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
)
-try:
+async def gen():
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
break
-return Response(status_code=200)
-except Exception as e:
-logger.exception(e)
-return Response(status_code=503)
+tic = time.time()
+task = asyncio.create_task(gen())
+while time.time() < tic + HEALTH_CHECK_TIMEOUT:
+await asyncio.sleep(1)
+if _global_state.tokenizer_manager.last_receive_tstamp > tic:
+task.cancel()
+_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+return Response(status_code=200)
+task.cancel()
+tic_time = time.strftime("%H:%M:%S", time.localtime(tic))
+last_receive_time = time.strftime(
+"%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp)
+)
+logger.error(
+f"Health check failed. Server couldn't get a response from detokenizer for last "
+f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. "
+f"last_heartbeat time: {last_receive_time}"
+)
+_global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+return Response(status_code=503)
@app.get("/get_model_info")
@@ -156,13 +204,21 @@ async def get_model_info():
@app.get("/get_server_info")
async def get_server_info():
+internal_states = await _global_state.tokenizer_manager.get_internal_state()
return {
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**_global_state.scheduler_info,
+**internal_states,
"version": __version__,
}
+@app.api_route("/set_internal_state", methods=["POST", "PUT"])
+async def set_internal_state(obj: SetInternalStateReq, request: Request):
+res = await _global_state.tokenizer_manager.set_internal_state(obj)
+return res
# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
async def generate_request(obj: GenerateReqInput, request: Request):
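Tying back to the commit title, this endpoint can now return logprobs even when chunked prefill is active. A hedged usage sketch (the port and launch flag are defaults/assumptions, and the exact meta_info field names may vary by version):

    import requests

    # Assumes a local server, e.g.:
    #   python -m sglang.launch_server --model-path <model> --chunked-prefill-size 4096
    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"max_new_tokens": 8, "temperature": 0.0},
            "return_logprob": True,
        },
    )
    print(resp.json()["meta_info"])  # logprob fields are reported in meta_info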
@@ -179,6 +235,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
) + b"\n\n"
except ValueError as e:
out = {"error": {"message": str(e)}}
+logger.error(f"Error: {e}")
yield b"data: " + orjson.dumps(
out, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
@@ -236,9 +293,14 @@ async def flush_cache():
@app.api_route("/start_profile", methods=["GET", "POST"])
-async def start_profile_async():
+async def start_profile_async(obj: Optional[ProfileReqInput] = None):
"""Start profiling."""
-_global_state.tokenizer_manager.start_profile()
+if obj is None:
+obj = ProfileReqInput()
+await _global_state.tokenizer_manager.start_profile(
+obj.output_dir, obj.num_steps, obj.activities
+)
return Response(
content="Start profiling.\n",
status_code=200,
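With the change above, /start_profile accepts an optional body rather than always profiling with defaults. A sketch of a parameterized call; the field names come straight from ProfileReqInput as used in the hunk, while the concrete values and activity labels are assumptions:

    import requests

    requests.post(
        "http://localhost:30000/start_profile",
        json={
            "output_dir": "/tmp/sglang_profile",  # illustrative path
            "num_steps": 5,
            "activities": ["CPU", "GPU"],  # activity labels assumed
        },
    )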
@@ -257,11 +319,15 @@ async def stop_profile_async():
@app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk in-place without re-launching the server."""
success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
obj, request
"""Update the weights from disk inplace without re-launching the server."""
success, message, num_paused_requests = (
await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
)
content = {"success": success, "message": message}
content = {
"success": success,
"message": message,
"num_paused_requests": num_paused_requests,
}
if success:
return ORJSONResponse(
content,
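Callers of this endpoint now receive one extra field in the response body. A usage sketch (the model path is illustrative; the response shape follows the content dict built above):

    import requests

    resp = requests.post(
        "http://localhost:30000/update_weights_from_disk",
        json={"model_path": "/models/my-finetune-v2"},  # path is illustrative
    )
    # e.g. {"success": true, "message": "...", "num_paused_requests": 3}
    print(resp.json())

Reporting num_paused_requests lets operators see how much in-flight traffic was held back while the weights were swapped.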
@@ -323,7 +389,7 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
async def release_memory_occupation(
obj: ReleaseMemoryOccupationReqInput, request: Request
):
"""Release GPU occupation temporarily"""
"""Release GPU memory occupation temporarily."""
try:
await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
except Exception as e:
@@ -334,7 +400,7 @@ async def release_memory_occupation(
async def resume_memory_occupation(
obj: ResumeMemoryOccupationReqInput, request: Request
):
"""Resume GPU occupation"""
"""Resume GPU memory occupation."""
try:
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
except Exception as e:
@@ -357,7 +423,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
"""Close the session"""
"""Close the session."""
try:
await _global_state.tokenizer_manager.close_session(obj, request)
return Response(status_code=200)
@@ -367,7 +433,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
@app.api_route("/configure_logging", methods=["GET", "POST"])
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
"""Close the session"""
"""Configure the request logging options."""
_global_state.tokenizer_manager.configure_logging(obj)
return Response(status_code=200)
@@ -511,6 +577,7 @@ def _create_error_response(e):
def launch_server(
server_args: ServerArgs,
pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
+launch_callback: Optional[Callable[[], None]] = None,
):
"""
Launch SRT (SGLang Runtime) Server.
@@ -544,21 +611,23 @@ def launch_server(
add_prometheus_middleware(app)
enable_func_timer()
-# Send a warmup request
-t = threading.Thread(
+# Send a warmup request - we will create the thread and launch it
+# in the lifespan after all other warmups have fired.
+warmup_thread = threading.Thread(
target=_wait_and_warmup,
args=(
server_args,
pipe_finish_writer,
_global_state.tokenizer_manager.image_token_id,
+launch_callback,
),
)
-t.start()
+app.warmup_thread = warmup_thread
try:
# Update logging configs
set_uvicorn_logging_configs()
+app.server_args = server_args
# Listen for HTTP requests
uvicorn.run(
app,
@@ -569,10 +638,15 @@ def launch_server(
loop="uvloop",
)
finally:
-t.join()
+warmup_thread.join()
-def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
+def _wait_and_warmup(
+server_args: ServerArgs,
+pipe_finish_writer: Optional[multiprocessing.connection.Connection],
+image_token_text: str,
+launch_callback: Optional[Callable[[], None]] = None,
+):
headers = {}
url = server_args.url()
if server_args.api_key:
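The new launch_callback parameter threaded through the two signatures above gives embedders a readiness hook: launch_server blocks while uvicorn runs, and the warmup thread invokes the callback once the warmup request succeeds. A minimal sketch of wiring it to an Event (import paths assumed):

    import multiprocessing

    # Import paths assumed; launch_server is the function shown in this diff.
    from sglang.srt.entrypoints.http_server import launch_server
    from sglang.srt.server_args import ServerArgs

    def serve(server_args: ServerArgs, ready: multiprocessing.Event):
        # A parent process can ready.wait() before routing traffic.
        launch_server(server_args, launch_callback=ready.set)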
@@ -614,8 +688,16 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
else:
json_data["text"] = "The capital city of France is"
+# Debug dumping
+if server_args.debug_tensor_dump_input_file:
+json_data.pop("text", None)
+json_data["input_ids"] = np.load(
+server_args.debug_tensor_dump_input_file
+).tolist()
+json_data["sampling_params"]["max_new_tokens"] = 0
try:
-for _ in range(server_args.dp_size):
+for i in range(server_args.dp_size):
res = requests.post(
url + request_name,
json=json_data,
@@ -640,3 +722,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
if server_args.delete_ckpt_after_loading:
delete_directory(server_args.model_path)
+if server_args.debug_tensor_dump_input_file:
+kill_process_tree(os.getpid())
+if launch_callback is not None:
+launch_callback()
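The debug path above expects a NumPy file of token ids, replayed at warmup with max_new_tokens=0 (prefill only) before the server kills itself. A hypothetical way to prepare such a file (the token values and file path are illustrative, and the CLI flag name is inferred from the ServerArgs field):

    import numpy as np

    # Dump token ids for --debug-tensor-dump-input-file to replay.
    input_ids = np.array([1, 2735, 374, 264, 1296], dtype=np.int64)
    np.save("/tmp/debug_input_ids.npy", input_ids)
    # Launch with (flag name inferred from the field):
    #   --debug-tensor-dump-input-file /tmp/debug_input_ids.npy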