Implement Standalone gRPC Server for SGLang Python Scheduler (#10283)
This commit is contained in:
580
python/sglang/srt/entrypoints/grpc_request_manager.py
Normal file
580
python/sglang/srt/entrypoints/grpc_request_manager.py
Normal file
@@ -0,0 +1,580 @@
|
||||
"""
|
||||
gRPC Request Manager - Orchestrates request lifecycle without tokenization.
|
||||
Mimics TokenizerManager's state management and ZMQ communication patterns.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import grpc
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
HealthCheckOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GrpcSignalHandler:
    """Minimal signal handling for the gRPC server process.

    Real crash handling (dumps, diagnostics) lives in the scheduler process;
    this class only coordinates shutdown of the gRPC side.
    """

    def __init__(self, grpc_manager):
        # Manager whose `gracefully_exit` flag is flipped on SIGTERM.
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Handle SIGTERM by gracefully shutting down gRPC server."""
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        )
        # Setting the flag lets the manager's loops drain and exit on their own.
        self.grpc_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Handle SIGQUIT from failed scheduler process."""
        logger.error(
            "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
        )
        logger.info(
            "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
        )
        # The scheduler owns crash dumps; here we only take the whole
        # process tree down cleanly.
        kill_process_tree(os.getpid(), include_parent=True)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class GrpcReqState:
    """Per-request bookkeeping for a gRPC request.

    Mirrors TokenizerManager's ReqState: identity, the queue the gRPC layer
    streams outputs from, timing metrics, and accumulated generation output.
    """

    # --- Identification ---
    request_id: str
    grpc_context: Optional[grpc.aio.ServicerContext]

    # --- Communication ---
    out_queue: asyncio.Queue  # outputs consumed by the gRPC servicer
    finished: bool  # set once the scheduler reports completion
    event: asyncio.Event  # signaled on completion/abort
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # --- Metrics (same fields as TokenizerManager's ReqState) ---
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # --- Streaming state ---
    last_output_offset: int = 0  # chars of `text` already emitted to the client
    stream_finished: bool = False

    # --- Output accumulation ---
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)

    # --- Session state ---
    session_id: Optional[str] = None
    is_session_request: bool = False
|
||||
|
||||
|
||||
class GrpcRequestManager:
    """
    Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
    behaviors without tokenization.

    Requests enter through `generate_request` / `embedding_request`, are
    forwarded to the scheduler process over ZMQ, and `handle_loop` routes the
    scheduler's outputs back into each request's asyncio queue.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialize ZMQ sockets and request-tracking state."""
        self.server_args = server_args
        self.port_args = port_args

        # ZMQ Communication Setup (same pattern as TokenizerManager)
        context = zmq.asyncio.Context(2)

        # Socket for receiving outputs from scheduler
        self.recv_from_scheduler = get_zmq_socket(
            context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
        )

        # Socket for sending requests to scheduler
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
        )

        # State Management (from TokenizerManager)
        self.rid_to_state: Dict[str, GrpcReqState] = {}
        self.asyncio_tasks: set = set()
        self.gracefully_exit = False
        self.no_create_loop = False
        self.event_loop = None

        # Pause/Resume Control
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

        # Metrics
        self.request_counter = 0
        self.request_counter_lock = asyncio.Lock()
        self.last_receive_tstamp = time.time()

        # Crash dump for debugging
        self.crash_dump_request_list = []
        self.crash_dump_performed = False

        logger.info(
            f"GrpcRequestManager initialized with ZMQ IPC: "
            f"recv={port_args.detokenizer_ipc_name}, "
            f"send={port_args.scheduler_input_ipc_name}"
        )

    def _track_task(self, task) -> None:
        """Keep a strong reference to a background task until it finishes.

        The event loop holds only weak references to tasks, so a task created
        with `asyncio.create_task` and immediately dropped may be garbage
        collected before it runs to completion.
        """
        self.asyncio_tasks.add(task)
        task.add_done_callback(self.asyncio_tasks.discard)

    async def generate_request(
        self,
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ) -> asyncio.Queue:
        """
        Submit a generation request to the scheduler.
        Returns a queue for streaming outputs.

        Raises:
            RuntimeError: if the request could not be sent to the scheduler.
        """
        # Generate request ID if not provided
        if request_id is None:
            async with self.request_counter_lock:
                request_id = f"grpc-{self.request_counter}"
                self.request_counter += 1

        obj.rid = request_id

        # TODO: support log_request

        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=grpc_context,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Track session if needed
        if hasattr(obj, "session_params") and obj.session_params:
            state.session_id = obj.session_params.session_id
            state.is_session_request = True

        # Register state
        self.rid_to_state[request_id] = state
        self.record_request_for_crash_dump(obj)

        # Send to scheduler via ZMQ
        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            # Clean up on failure
            del self.rid_to_state[request_id]
            raise RuntimeError(f"Failed to send request to scheduler: {e}") from e

        return state.out_queue

    async def embedding_request(
        self,
        obj: TokenizedEmbeddingReqInput,
        request_id: Optional[str] = None,
    ) -> asyncio.Future:
        """
        Submit an embedding request to the scheduler.
        Returns a future that will contain the embedding result.
        """
        # Generate request ID if not provided
        if request_id is None:
            async with self.request_counter_lock:
                request_id = f"grpc-embed-{self.request_counter}"
                self.request_counter += 1

        obj.rid = request_id

        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=None,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Register state
        self.rid_to_state[request_id] = state

        # Create future for result
        future = asyncio.Future()

        # Send to scheduler
        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            del self.rid_to_state[request_id]
            future.set_exception(e)
            return future

        # Wait for result in background
        async def wait_for_result():
            try:
                # Wait for completion
                await state.event.wait()
                # Get result from queue
                result = await state.out_queue.get()
                future.set_result(result)
            except Exception as e:
                future.set_exception(e)
            finally:
                # Clean up
                if request_id in self.rid_to_state:
                    del self.rid_to_state[request_id]

        # Keep a reference so the waiter task cannot be garbage collected
        # before it resolves the future.
        self._track_task(asyncio.create_task(wait_for_result()))
        return future

    async def abort_request(self, request_id: str) -> bool:
        """Abort a running request.

        Returns True if the abort was delivered to the scheduler, False if the
        request is unknown or the send failed.
        """
        if request_id not in self.rid_to_state:
            return False

        # Send abort to scheduler
        abort_req = AbortReq(rid=request_id)
        try:
            await self._send_to_scheduler(abort_req)
        except Exception as e:
            logger.error(f"Failed to send abort request: {e}")
            return False

        # Mark as finished
        state = self.rid_to_state.get(request_id)
        if state:
            state.finished = True
            state.stream_finished = True
            state.event.set()

            # Send abort notification to output queue
            await state.out_queue.put({"error": "Request aborted", "abort": True})

        return True

    async def pause_generation(self):
        """Pause generation processing."""
        async with self.is_pause_cond:
            self.is_pause = True
            logger.info("Generation paused")

    async def resume_generation(self):
        """Resume generation processing."""
        async with self.is_pause_cond:
            self.is_pause = False
            self.is_pause_cond.notify_all()
            logger.info("Generation resumed")

    async def handle_loop(self):
        """
        Main event loop - processes outputs from scheduler.
        Mimics TokenizerManager's handle_loop.
        """
        while not self.gracefully_exit:
            try:
                # Receive from scheduler
                recv_obj = await self.recv_from_scheduler.recv_pyobj()
                self.last_receive_tstamp = time.time()

                # Check for pause
                async with self.is_pause_cond:
                    while self.is_pause:
                        await self.is_pause_cond.wait()

                # Handle different output types
                if isinstance(recv_obj, BatchTokenIDOut):
                    await self._handle_batch_output(recv_obj)
                elif isinstance(recv_obj, BatchEmbeddingOut):
                    await self._handle_embedding_output(recv_obj)
                elif isinstance(recv_obj, HealthCheckOutput):
                    await self._handle_health_check_output(recv_obj)
                else:
                    logger.warning(f"Unknown output type: {type(recv_obj)}")

            except zmq.error.Again:
                # Timeout, check if we should exit
                if self.gracefully_exit:
                    break
                continue
            except Exception as e:
                logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
                if self.gracefully_exit:
                    break

    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
        """Handle batch generation output from scheduler.

        Fans each per-request slice of the batch out to the matching request's
        queue and finalizes requests whose finish reason is set.
        """
        # Process each request in the batch
        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Update metrics
            now = time.time()
            if state.first_token_time == 0.0:
                state.first_token_time = now
            state.last_time = now

            # Extract output for this request
            output_data = {
                "request_id": rid,
                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                "finished": batch_out.finished_reasons[i] is not None,
                "meta_info": {
                    "prompt_tokens": (
                        batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                    ),
                    "completion_tokens": (
                        batch_out.completion_tokens[i]
                        if batch_out.completion_tokens
                        else 0
                    ),
                    "finish_reason": (
                        str(batch_out.finished_reasons[i])
                        if batch_out.finished_reasons[i]
                        else None
                    ),
                },
            }

            # Add logprobs if available
            if batch_out.output_token_logprobs_val and i < len(
                batch_out.output_token_logprobs_val
            ):
                output_data["logprobs"] = {
                    "tokens": batch_out.output_token_logprobs_val[i],
                    "top_logprobs": (
                        batch_out.output_top_logprobs_val[i]
                        if batch_out.output_top_logprobs_val
                        and i < len(batch_out.output_top_logprobs_val)
                        else None
                    ),
                }

            # Update state (decoded_texts is cumulative, so only append the
            # portion past the previously seen offset)
            if output_data["text"]:
                state.text += output_data["text"][state.last_output_offset :]
                state.last_output_offset = len(output_data["text"])

            if output_data["token_ids"]:
                state.output_ids.extend(output_data["token_ids"])

            # Send to output queue
            await state.out_queue.put(output_data)

            # Handle completion
            if output_data["finished"]:
                state.finished = True
                state.finished_time = now
                state.stream_finished = True
                state.event.set()

                # Remove from tracking after a delay. Bind the rid as a
                # default argument: `rid` is the loop variable, and a
                # late-binding closure would make every cleanup task delete
                # whichever rid the loop ended on.
                async def cleanup(done_rid=rid):
                    await asyncio.sleep(5.0)
                    if done_rid in self.rid_to_state:
                        del self.rid_to_state[done_rid]

                self._track_task(asyncio.create_task(cleanup()))

    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
        """Handle batch embedding output from scheduler."""
        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Create result
            result = {
                "request_id": rid,
                "embedding": batch_out.embeddings[i],
                "prompt_tokens": (
                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                ),
                "finish_reason": (
                    batch_out.finish_reason[i] if batch_out.finish_reason else None
                ),
            }

            # Send result
            await state.out_queue.put(result)

            # Mark as finished
            state.finished = True
            state.finished_time = time.time()
            state.event.set()

    async def _handle_health_check_output(self, health_out: HealthCheckOutput):
        """Handle health check output from scheduler."""
        rid = health_out.rid

        if rid not in self.rid_to_state:
            logger.warning(f"Health check output for unknown request: {rid}")
            return

        state = self.rid_to_state[rid]

        # Create health check result
        result = {
            "request_id": rid,
            "healthy": True,  # If we got a response, scheduler is healthy
            "output_text": (
                health_out.output_str if hasattr(health_out, "output_str") else ""
            ),
            "finish_reason": (
                health_out.finish_reason
                if hasattr(health_out, "finish_reason")
                else "stop"
            ),
        }

        # Send result
        await state.out_queue.put(result)

        # Mark as finished
        state.finished = True
        state.finished_time = time.time()
        state.event.set()

    async def _send_to_scheduler(self, obj):
        """Send an object to the scheduler via ZMQ.

        Raises:
            Exception: re-raises whatever the socket send raised.
        """
        try:
            # send_pyobj on a zmq.asyncio socket returns an awaitable; without
            # awaiting it, send errors could never reach this except block.
            await self.send_to_scheduler.send_pyobj(obj)
        except Exception as e:
            logger.error(f"Failed to send to scheduler: {e}")
            raise

    def record_request_for_crash_dump(self, obj):
        """Record request metadata for a potential crash dump (bounded to 100)."""
        if len(self.crash_dump_request_list) < 100:
            self.crash_dump_request_list.append(
                {
                    "time": time.time(),
                    "request_id": getattr(obj, "rid", "unknown"),
                    "type": type(obj).__name__,
                }
            )

    def dump_requests_before_crash(self):
        """Log recently received requests to aid post-crash debugging.

        Called by `print_exception_wrapper` when a manager task dies; without
        this method, that crash path would raise AttributeError.
        """
        if self.crash_dump_performed:
            return
        self.crash_dump_performed = True
        logger.error(
            f"Dumping {len(self.crash_dump_request_list)} recent requests before crash: "
            f"{self.crash_dump_request_list}"
        )

    async def shutdown(self):
        """Gracefully shutdown the request manager."""
        logger.info("Shutting down GrpcRequestManager")
        self.gracefully_exit = True

        # Cancel all pending requests. Iterate over a snapshot because
        # background cleanup tasks may mutate rid_to_state while we await.
        for rid, state in list(self.rid_to_state.items()):
            if not state.finished:
                await state.out_queue.put(
                    {"error": "Server shutting down", "shutdown": True}
                )
                state.finished = True
                state.event.set()

        # Wait for tasks to complete
        if self.asyncio_tasks:
            await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

        # Close ZMQ sockets
        self.recv_from_scheduler.close()
        self.send_to_scheduler.close()

        logger.info("GrpcRequestManager shutdown complete")

    def get_server_info(self) -> Dict[str, Any]:
        """Get server information for health checks."""
        return {
            "active_requests": len(self.rid_to_state),
            "paused": self.is_pause,
            "last_receive_time": self.last_receive_tstamp,
        }

    def auto_create_handle_loop(self):
        """Automatically create and start the handle_loop task, matching TokenizerManager pattern."""
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self._track_task(loop.create_task(print_exception_wrapper(self.handle_loop)))

        self.event_loop = loop

        # We cannot add signal handler when the grpc manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = GrpcSignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the grpc request manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "grpc request manager when SIGTERM is received."
            )
        self._track_task(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        """Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern."""
        while not self.gracefully_exit:
            await asyncio.sleep(1.0)
|
||||
|
||||
|
||||
async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print exception.
    We do another wrapper to handle the exception.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {traceback}")
        # If `func` is a bound method of a GrpcRequestManager, dump its
        # recent requests before tearing the process down.
        owner = getattr(func, "__self__", None)
        if isinstance(owner, GrpcRequestManager):
            owner.dump_requests_before_crash()
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
|
||||
Reference in New Issue
Block a user