Improve: Rename TokenizerManager to StdOrchestrator (#3116)

This commit is contained in:
fzyzcjy
2025-02-23 16:30:58 +08:00
committed by GitHub
parent 3f41b18455
commit 45360b2fa9
11 changed files with 116 additions and 130 deletions

View File

@@ -54,7 +54,6 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.adapter import (
v1_batches,
@@ -69,6 +68,7 @@ from sglang.srt.openai_api.adapter import (
v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.orchestration.std.orchestrator import StdOrchestrator
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
add_api_key_middleware,
@@ -97,7 +97,7 @@ app.add_middleware(
# Store global states
@dataclasses.dataclass
class _GlobalState:
tokenizer_manager: TokenizerManager
orchestrator: StdOrchestrator
scheduler_info: Dict
@@ -124,7 +124,7 @@ async def health_generate(request: Request) -> Response:
sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
if _global_state.tokenizer_manager.is_generation:
if _global_state.orchestrator.is_generation:
gri = GenerateReqInput(
input_ids=[0], sampling_params=sampling_params, log_metrics=False
)
@@ -134,7 +134,7 @@ async def health_generate(request: Request) -> Response:
)
try:
async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
async for _ in _global_state.orchestrator.generate_request(gri, request):
break
return Response(status_code=200)
except Exception as e:
@@ -146,9 +146,9 @@ async def health_generate(request: Request) -> Response:
async def get_model_info():
"""Get the model information."""
result = {
"model_path": _global_state.tokenizer_manager.model_path,
"tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
"is_generation": _global_state.tokenizer_manager.is_generation,
"model_path": _global_state.orchestrator.model_path,
"tokenizer_path": _global_state.orchestrator.server_args.tokenizer_path,
"is_generation": _global_state.orchestrator.is_generation,
}
return result
@@ -156,7 +156,7 @@ async def get_model_info():
@app.get("/get_server_info")
async def get_server_info():
return {
**dataclasses.asdict(_global_state.tokenizer_manager.server_args),
**dataclasses.asdict(_global_state.orchestrator.server_args),
**_global_state.scheduler_info,
"version": __version__,
}
@@ -170,7 +170,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
async def stream_results() -> AsyncIterator[bytes]:
try:
async for out in _global_state.tokenizer_manager.generate_request(
async for out in _global_state.orchestrator.generate_request(
obj, request
):
yield b"data: " + orjson.dumps(
@@ -186,11 +186,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
background=_global_state.tokenizer_manager.create_abort_task(obj),
background=_global_state.orchestrator.create_abort_task(obj),
)
else:
try:
ret = await _global_state.tokenizer_manager.generate_request(
ret = await _global_state.orchestrator.generate_request(
obj, request
).__anext__()
return ret
@@ -203,7 +203,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
async def encode_request(obj: EmbeddingReqInput, request: Request):
"""Handle an embedding request."""
try:
ret = await _global_state.tokenizer_manager.generate_request(
ret = await _global_state.orchestrator.generate_request(
obj, request
).__anext__()
return ret
@@ -215,7 +215,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
async def classify_request(obj: EmbeddingReqInput, request: Request):
"""Handle a reward model request. Now the arguments and return values are the same as embedding models."""
try:
ret = await _global_state.tokenizer_manager.generate_request(
ret = await _global_state.orchestrator.generate_request(
obj, request
).__anext__()
return ret
@@ -226,7 +226,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
@app.post("/flush_cache")
async def flush_cache():
"""Flush the radix cache."""
_global_state.tokenizer_manager.flush_cache()
_global_state.orchestrator.flush_cache()
return Response(
content="Cache flushed.\nPlease check backend logs for more details. "
"(When there are running or waiting requests, the operation will not be performed.)\n",
@@ -237,7 +237,7 @@ async def flush_cache():
@app.api_route("/start_profile", methods=["GET", "POST"])
async def start_profile_async():
"""Start profiling."""
_global_state.tokenizer_manager.start_profile()
_global_state.orchestrator.start_profile()
return Response(
content="Start profiling.\n",
status_code=200,
@@ -247,7 +247,7 @@ async def start_profile_async():
@app.api_route("/stop_profile", methods=["GET", "POST"])
async def stop_profile_async():
"""Stop profiling."""
_global_state.tokenizer_manager.stop_profile()
_global_state.orchestrator.stop_profile()
return Response(
content="Stop profiling. This will take some time.\n",
status_code=200,
@@ -257,7 +257,7 @@ async def stop_profile_async():
@app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
"""Update the weights from disk in-place without re-launching the server."""
success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
success, message = await _global_state.orchestrator.update_weights_from_disk(
obj, request
)
content = {"success": success, "message": message}
@@ -278,7 +278,7 @@ async def init_weights_update_group(
obj: InitWeightsUpdateGroupReqInput, request: Request
):
"""Initialize the parameter update group."""
success, message = await _global_state.tokenizer_manager.init_weights_update_group(
success, message = await _global_state.orchestrator.init_weights_update_group(
obj, request
)
content = {"success": success, "message": message}
@@ -293,10 +293,8 @@ async def update_weights_from_distributed(
obj: UpdateWeightsFromDistributedReqInput, request: Request
):
"""Update model parameter from distributed online."""
success, message = (
await _global_state.tokenizer_manager.update_weights_from_distributed(
obj, request
)
success, message = await _global_state.orchestrator.update_weights_from_distributed(
obj, request
)
content = {"success": success, "message": message}
if success:
@@ -309,7 +307,7 @@ async def update_weights_from_distributed(
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
"""Get model parameter by name."""
try:
ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request)
ret = await _global_state.orchestrator.get_weights_by_name(obj, request)
if ret is None:
return _create_error_response("Get parameter by name failed")
else:
@@ -324,7 +322,7 @@ async def release_memory_occupation(
):
"""Release GPU occupation temporarily"""
try:
await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
await _global_state.orchestrator.release_memory_occupation(obj, request)
except Exception as e:
return _create_error_response(e)
@@ -335,7 +333,7 @@ async def resume_memory_occupation(
):
"""Resume GPU occupation"""
try:
await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
await _global_state.orchestrator.resume_memory_occupation(obj, request)
except Exception as e:
return _create_error_response(e)
@@ -344,7 +342,7 @@ async def resume_memory_occupation(
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
try:
session_id = await _global_state.tokenizer_manager.open_session(obj, request)
session_id = await _global_state.orchestrator.open_session(obj, request)
if session_id is None:
raise Exception(
"Failed to open the session. Check if a session with the same id is still open."
@@ -358,7 +356,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
async def close_session(obj: CloseSessionReqInput, request: Request):
"""Close the session"""
try:
await _global_state.tokenizer_manager.close_session(obj, request)
await _global_state.orchestrator.close_session(obj, request)
return Response(status_code=200)
except Exception as e:
return _create_error_response(e)
@@ -367,7 +365,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
@app.api_route("/configure_logging", methods=["GET", "POST"])
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
"""Close the session"""
_global_state.tokenizer_manager.configure_logging(obj)
_global_state.orchestrator.configure_logging(obj)
return Response(status_code=200)
@@ -398,24 +396,24 @@ async def function_call_request(obj: FunctionCallReqInput, request: Request):
@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
return await v1_completions(_global_state.tokenizer_manager, raw_request)
return await v1_completions(_global_state.orchestrator, raw_request)
@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
return await v1_chat_completions(_global_state.orchestrator, raw_request)
@app.post("/v1/embeddings", response_class=ORJSONResponse)
async def openai_v1_embeddings(raw_request: Request):
response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
response = await v1_embeddings(_global_state.orchestrator, raw_request)
return response
@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
"""Show available models."""
served_model_names = [_global_state.tokenizer_manager.served_model_name]
served_model_names = [_global_state.orchestrator.served_model_name]
model_cards = []
for served_model_name in served_model_names:
model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
@@ -425,7 +423,7 @@ def available_models():
@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
return await v1_files_create(
file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
file, purpose, _global_state.orchestrator.server_args.file_storage_pth
)
@@ -437,13 +435,13 @@ async def delete_file(file_id: str):
@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
return await v1_batches(_global_state.tokenizer_manager, raw_request)
return await v1_batches(_global_state.orchestrator, raw_request)
@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
# https://platform.openai.com/docs/api-reference/batch/cancel
return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
return await v1_cancel_batch(_global_state.orchestrator, batch_id)
@app.get("/v1/batches/{batch_id}")
@@ -492,18 +490,18 @@ def launch_server(
- HTTP server: A FastAPI server that routes requests to the engine.
- The engine consists of three components:
1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
1. StdOrchestrator: Tokenizes the requests and sends them to the scheduler.
2. Scheduler (subprocess): Receives requests from the StdOrchestrator, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the StdOrchestrator.
Note:
1. The HTTP server, Engine, and TokenizerManager both run in the main process.
1. The HTTP server, Engine, and StdOrchestrator all run in the main process.
2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
"""
tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
orchestrator, scheduler_info = _launch_subprocesses(server_args=server_args)
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
orchestrator=orchestrator,
scheduler_info=scheduler_info,
)
)
@@ -523,7 +521,7 @@ def launch_server(
args=(
server_args,
pipe_finish_writer,
_global_state.tokenizer_manager.image_token_id,
_global_state.orchestrator.image_token_id,
),
)
t.start()