Clean up server_args, triton cache manager (#8332)
@@ -71,7 +71,6 @@ from sglang.srt.utils import (
     is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
-    maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
     set_ulimit,
@@ -637,11 +636,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()
 
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
     # Check flashinfer version
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
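For context on the deletion above: maybe_set_triton_cache_manager worked around a Triton kernel-cache race when several processes (tp_size * dp_size > 1) compiled kernels concurrently, and the FIXME notes it becomes unnecessary once triton-lang/triton#4295 ships in a released dependency. A minimal sketch of that workaround pattern, assuming a hypothetical CustomCacheManager class path; the real helper in sglang.srt.utils may differ:

import os

def maybe_set_triton_cache_manager() -> None:
    # Triton reads TRITON_CACHE_MANAGER ("module:Class") to pick the class
    # that manages its compiled-kernel cache; pointing it at a per-process
    # manager avoids concurrent writes to the same cache files.
    if "TRITON_CACHE_MANAGER" not in os.environ:
        os.environ["TRITON_CACHE_MANAGER"] = "sglang.srt.utils:CustomCacheManager"  # assumed path
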
@@ -107,6 +107,8 @@ from sglang.version import __version__
 logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
+HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
+
 
 # Store global states
 @dataclasses.dataclass
@@ -212,9 +214,6 @@ async def validate_json_request(raw_request: Request):
         )
 
 
-HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
-
-
 ##### Native API endpoints #####
 
 
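The two hunks above only relocate HEALTH_CHECK_TIMEOUT from the endpoint section up to the other module-level globals; its value and env override are unchanged. A standalone sketch of how an env-tunable timeout like this typically guards a health probe (the handler below is illustrative, not this file's /health implementation):

import asyncio
import os

from fastapi import FastAPI, Response

HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))

app = FastAPI()

async def probe_engine() -> None:
    # Stand-in for a tiny generation request that proves the engine is alive.
    await asyncio.sleep(0)

@app.get("/health")
async def health() -> Response:
    try:
        await asyncio.wait_for(probe_engine(), timeout=HEALTH_CHECK_TIMEOUT)
    except asyncio.TimeoutError:
        return Response(status_code=503)  # engine did not answer in time
    return Response(status_code=200)
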
@@ -807,6 +806,24 @@ async def retrieve_model(model: str):
     )
 
 
+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
+    """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
+    return await raw_request.app.state.openai_serving_score.handle_request(
+        request, raw_request
+    )
+
+
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
+
 ## SageMaker API
 @app.get("/ping")
 async def sagemaker_health() -> Response:
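The /v1/score and /v1/rerank handlers added here are the same ones deleted further down; the commit moves them up next to the other OpenAI-compatible routes. A hypothetical client call against the rerank route; the port and the request fields (query, documents) are assumptions, so check V1RerankReqInput for the authoritative schema:

import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",  # default sglang port assumed
    json={
        "query": "What is the capital of France?",  # assumed field name
        "documents": [                              # assumed field name
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
        ],
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())  # expected: relevance scores, one per document
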
@@ -852,24 +869,6 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
     return ORJSONResponse({"predictions": ret})
 
 
-@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
-async def v1_score_request(request: ScoringRequest, raw_request: Request):
-    """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
-    return await raw_request.app.state.openai_serving_score.handle_request(
-        request, raw_request
-    )
-
-
-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
-
-
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
@@ -916,15 +915,6 @@ def launch_server(
     add_prometheus_middleware(app)
     enable_func_timer()
 
-    image_token_text = None
-    if (
-        tokenizer_manager.image_token_id is not None
-        and not server_args.skip_tokenizer_init
-    ):
-        image_token_text = tokenizer_manager.tokenizer.decode(
-            [tokenizer_manager.image_token_id]
-        )
-
     # Send a warmup request - we will create the thread and launch it
     # in the lifespan after all other warmups have fired.
     warmup_thread = threading.Thread(
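After this hunk, launch_server no longer computes image_token_text, so the warmup thread (trimmed in the next two hunks) takes one fewer argument. The create-now, start-later pattern the comment describes, as a self-contained sketch with placeholder names and arguments:

import threading
from contextlib import asynccontextmanager

from fastapi import FastAPI

def _wait_and_warmup_sketch(server_args, pipe_finish_writer, launch_callback=None):
    # Illustrative body: fire a warmup request, then signal readiness.
    if launch_callback is not None:
        launch_callback()

warmup_thread = threading.Thread(
    target=_wait_and_warmup_sketch,
    args=({}, None, None),  # placeholders for server_args, pipe, callback
)

@asynccontextmanager
async def lifespan(app: FastAPI):
    # The thread is created at launch time but only started here, once the
    # app's other startup warmups have fired.
    warmup_thread.start()
    yield
    warmup_thread.join()

app = FastAPI(lifespan=lifespan)
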
@@ -932,7 +922,6 @@ def launch_server(
         args=(
             server_args,
             pipe_finish_writer,
-            image_token_text,
             launch_callback,
         ),
     )
@@ -1066,7 +1055,6 @@ def _execute_server_warmup(
 def _wait_and_warmup(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[multiprocessing.connection.Connection],
-    image_token_text: str,
     launch_callback: Optional[Callable[[], None]] = None,
 ):
     if not server_args.skip_server_warmup: