Clean up server_args, triton cache manager (#8332)

Author:    Lianmin Zheng
Date:      2025-07-25 14:14:51 -07:00
Committer: GitHub
Parent:    f8260f2539
Commit:    ed2e313eb6
12 changed files with 128 additions and 204 deletions

View File

@@ -71,7 +71,6 @@ from sglang.srt.utils import (
     is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
-    maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
     set_ulimit,
@@ -637,11 +636,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()
 
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
     # Check flashinfer version
     if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
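
For context, the removed helper worked around Triton compile-cache races when multiple processes (tp_size * dp_size > 1) shared one on-disk cache; per the FIXME, upstream triton-lang/triton#4295 makes it unnecessary. A minimal sketch of the idea follows. TRITON_CACHE_MANAGER is Triton's real hook for swapping the cache manager class, but the exact class path is an assumption (the real implementation lived in sglang.srt.utils):

import os

def maybe_set_triton_cache_manager() -> None:
    # Triton imports the "module:Class" named by this variable in place of
    # its default FileCacheManager, so each process can be routed to its own
    # cache directory and concurrent compilations stop clobbering each other.
    if "TRITON_CACHE_MANAGER" not in os.environ:
        # Class path is illustrative; see sglang.srt.utils for the real one.
        os.environ["TRITON_CACHE_MANAGER"] = "sglang.srt.utils:CustomCacheManager"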

View File

@@ -107,6 +107,8 @@ from sglang.version import __version__
 logger = logging.getLogger(__name__)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
+
 # Store global states
 @dataclasses.dataclass
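
This moves the constant up to module scope, next to the other import-time setup; the hunk below deletes the old definition. Because the value is read once at import time, an override must already be in the environment when the module is imported, e.g. (module path assumed from the diff context):

import os

# Must run before the server module is imported, since HEALTH_CHECK_TIMEOUT
# is evaluated at import time.
os.environ.setdefault("SGLANG_HEALTH_CHECK_TIMEOUT", "60")  # seconds

from sglang.srt.entrypoints import http_server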
@@ -212,9 +214,6 @@ async def validate_json_request(raw_request: Request):
         )
 
-HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
-
-
 ##### Native API endpoints #####
@@ -807,6 +806,24 @@ async def retrieve_model(model: str):
     )
 
+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
+    """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
+    return await raw_request.app.state.openai_serving_score.handle_request(
+        request, raw_request
+    )
+
+
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 ## SageMaker API
 @app.get("/ping")
 async def sagemaker_health() -> Response:
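
The two endpoints are moved, not changed: the same handlers previously registered after the Vertex AI section (removed in the next hunk) now sit ahead of the SageMaker block. A minimal client call against a local server might look like this; the "query"/"documents" payload fields are an assumption about V1RerankReqInput, and port 30000 is the usual local default, so check both against your version:

import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",  # default local port assumed
    json={
        "query": "what is the capital of France?",
        "documents": [
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # expected: documents annotated with relevance scores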
@@ -852,24 +869,6 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
     return ORJSONResponse({"predictions": ret})
 
-@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
-async def v1_score_request(request: ScoringRequest, raw_request: Request):
-    """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
-    return await raw_request.app.state.openai_serving_score.handle_request(
-        request, raw_request
-    )
-
-
-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
-
-
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
@@ -916,15 +915,6 @@ def launch_server(
         add_prometheus_middleware(app)
         enable_func_timer()
 
-    image_token_text = None
-    if (
-        tokenizer_manager.image_token_id is not None
-        and not server_args.skip_tokenizer_init
-    ):
-        image_token_text = tokenizer_manager.tokenizer.decode(
-            [tokenizer_manager.image_token_id]
-        )
-
     # Send a warmup request - we will create the thread and launch it
     # in the lifespan after all other warmups have fired.
     warmup_thread = threading.Thread(
@@ -932,7 +922,6 @@ def launch_server(
         args=(
             server_args,
             pipe_finish_writer,
-            image_token_text,
             launch_callback,
         ),
     )
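
Together with the signature change in the final hunk, the warmup thread is now created without the image token text. The call site reduces to roughly this sketch; the target is presumably _wait_and_warmup, whose new parameter list matches these args, though the diff does not show the target= line:

import threading

warmup_thread = threading.Thread(
    target=_wait_and_warmup,
    args=(server_args, pipe_finish_writer, launch_callback),
)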
@@ -1066,7 +1055,6 @@ def _execute_server_warmup(
 def _wait_and_warmup(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[multiprocessing.connection.Connection],
-    image_token_text: str,
     launch_callback: Optional[Callable[[], None]] = None,
 ):
     if not server_args.skip_server_warmup: