# sglang/python/sglang/srt/entrypoints/grpc_request_manager.py
# (scraped file-view metadata: 856 lines, 32 KiB, Python)

"""
gRPC Request Manager - Orchestrates request lifecycle without tokenization.
Mimics TokenizerManager's state management and ZMQ communication patterns.
"""
import asyncio
import copy
import dataclasses
import logging
import os
import signal
import sys
import threading
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
import grpc
import zmq
import zmq.asyncio
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchTokenIDOut,
HealthCheckOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
class GrpcSignalHandler:
    """Signal callbacks for the gRPC server process.

    Kept deliberately small: crash handling proper is the scheduler's job, so
    SIGTERM only flags the manager for shutdown and SIGQUIT tears down this
    process tree.
    """

    def __init__(self, grpc_manager):
        # Object whose ``gracefully_exit`` attribute the SIGTERM handler sets.
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """On SIGTERM, log the signal and flag the manager to exit gracefully."""
        text = f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        logger.warning(text)
        setattr(self.grpc_manager, "gracefully_exit", True)

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """On SIGQUIT from a failed scheduler, log and kill this process tree."""
        failure_msg = (
            "Received SIGQUIT from scheduler process. "
            "Scheduler failed, shutting down gRPC server."
        )
        note_msg = (
            "Note: Crash dumps are handled by the scheduler process, "
            "not the gRPC server."
        )
        logger.error(failure_msg)
        logger.info(note_msg)
        # No crash dump here on purpose: the scheduler produces those.
        own_pid = os.getpid()
        kill_process_tree(own_pid, include_parent=True)
@dataclasses.dataclass
class GrpcReqState:
    """State tracking for a gRPC request.

    One instance is created per request id and stored in
    ``GrpcRequestManager.rid_to_state``. The scheduler-output handlers write
    into it; the coroutine serving the request reads from it.

    Fields without defaults are declared first and therefore form the
    positional part of the generated ``__init__``; keep their order stable.
    """

    # Request identification
    request_id: str
    # Servicer context of the originating RPC, polled for client cancellation.
    # None when the caller did not supply one (embedding requests pass None).
    grpc_context: Optional[grpc.aio.ServicerContext]

    # Communication
    # Output dicts (token chunks, embedding results, abort/shutdown notices)
    # queued for whoever is consuming this request.
    out_queue: asyncio.Queue
    # Set once a final output, an abort, or a shutdown notice is delivered.
    finished: bool
    # Set together with ``finished``; embedding requests wait on this event.
    event: asyncio.Event
    # The tokenized request object that was sent to the scheduler.
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # Metrics (same as TokenizerManager's ReqState); wall-clock values from
    # time.time().
    created_time: float
    finished_time: float = 0.0
    # Stays 0.0 until the first batch output for this request is handled.
    first_token_time: float = 0.0
    last_time: float = 0.0
    # NOTE(review): not read or updated by the visible gRPC manager code;
    # presumably kept for parity with TokenizerManager's ReqState.
    last_completion_tokens: int = 1

    # Streaming state
    stream_finished: bool = False
    input_logprobs_sent: bool = False  # Track if input logprobs were sent in streaming

    # Token / logprob accumulation across batch outputs. ``output_ids`` becomes
    # the ``token_ids`` of the final non-streaming response; the logprob lists
    # are accumulated whenever the request asked for logprobs.
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)

    # Session state (filled from ``obj.session_params`` when present)
    session_id: Optional[str] = None
    is_session_request: bool = False
class GrpcRequestManager:
    """
    Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
    behaviors without tokenization.

    Already-tokenized requests are pushed to the scheduler over a ZMQ PUSH
    socket; ``handle_loop`` reads outputs back from a ZMQ PULL socket and
    routes them to the per-request queues kept in ``rid_to_state``.
    """

    # Interval (seconds) at which a waiting request wakes up to check whether
    # its gRPC client cancelled. This is a polling interval, NOT a deadline:
    # requests may legitimately wait much longer (queueing, long prefills).
    _RESPONSE_POLL_INTERVAL_S = 4.0
    # Grace period (seconds) before a finished request's state is dropped as a
    # safety net, in case its consumer never ran its own cleanup.
    _FINISHED_STATE_TTL_S = 5.0
    # How many of the most recent requests are remembered for crash dumps.
    _CRASH_DUMP_MAX_REQUESTS = 100

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialize the gRPC request manager.

        Args:
            server_args: Server configuration; also passed to
                ``start_disagg_service``.
            port_args: Supplies the IPC names of the scheduler-input socket
                (``scheduler_input_ipc_name``) and the output socket
                (``detokenizer_ipc_name``); both are bound here.
        """
        self.server_args = server_args
        self.port_args = port_args

        # ZMQ Communication Setup (same pattern as TokenizerManager)
        self.context = zmq.asyncio.Context(2)
        # Socket for receiving outputs from scheduler
        self.recv_from_scheduler = get_zmq_socket(
            self.context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
        )
        # Socket for sending requests to scheduler
        self.send_to_scheduler = get_zmq_socket(
            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
        )

        # State Management (from TokenizerManager)
        self.rid_to_state: Dict[str, GrpcReqState] = {}
        # Long-lived loops (handle_loop, sigterm_watchdog); cancelled on shutdown.
        self.asyncio_tasks: set = set()
        # Short-lived fire-and-forget tasks. The event loop only keeps weak
        # references to tasks, so strong ones are held here until completion.
        self._background_tasks: set = set()
        self.gracefully_exit = False
        self.no_create_loop = False
        self.event_loop = None

        # Pause/Resume Control
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

        # Metrics
        self.last_receive_tstamp = time.time()

        # Crash dump for debugging
        self.crash_dump_request_list = []
        self.crash_dump_performed = False

        # Bootstrap server for disaggregation mode
        self.bootstrap_server = start_disagg_service(server_args)

        logger.info(
            f"GrpcRequestManager initialized with ZMQ IPC: "
            f"recv={port_args.detokenizer_ipc_name}, "
            f"send={port_args.scheduler_input_ipc_name}"
        )
        if self.bootstrap_server:
            logger.info(
                f"Bootstrap server started for disaggregation mode: "
                f"{server_args.disaggregation_mode}"
            )

    # ------------------------------------------------------------------ #
    # Generation
    # ------------------------------------------------------------------ #

    async def generate_request(
        self,
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
        """
        Submit a generation request to the scheduler with n>1 parallel sampling support.

        This method implements the same two-phase approach as tokenizer_manager.py:
        1. Phase 1: Send prefix caching request (max_new_tokens=0)
        2. Phase 2: Send n generation requests that reuse the cached prefix

        Args:
            obj: Tokenized request. ``obj.sampling_params.n`` (default 1) is the
                number of samples; ``obj.stream`` selects streaming.
            request_id: Request id. A ``grpc-<uuid>`` id is generated when None.
                For n>1 the sub-requests use ``<id>-prefix`` and ``<id>-<i>``.
            grpc_context: Optional servicer context polled for cancellation.

        Yields:
            n<=1: the responses of the single request.
            n>1, streaming: every sub-request chunk as soon as it is ready; dict
                chunks are tagged with ``"index"`` (the sub-request number).
            n>1, non-streaming: a single list holding all sub-request
                responses, in sub-request order.
            If the prefix phase ends with an error dict (abort / shutdown),
            that dict is yielded alone and no sampling requests are sent.
        """
        n = getattr(obj.sampling_params, "n", 1)
        if n <= 1:
            async for response in self._handle_single_request(
                obj, request_id, grpc_context
            ):
                yield response
            return

        # N>1 handling - two-phase approach
        logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
        base_request_id = (
            request_id if request_id is not None else f"grpc-{uuid.uuid4().hex}"
        )

        # Phase 1: Cache the common prefix
        logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
        prefix_obj = copy.copy(obj)
        prefix_obj.sampling_params = copy.copy(obj.sampling_params)
        prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
        prefix_obj.sampling_params.n = 1  # Don't replicate prefix request

        prefix_error = None
        async for chunk in self._handle_single_request(
            prefix_obj, f"{base_request_id}-prefix", grpc_context
        ):
            # The normal prefix output is discarded; only an abort/shutdown
            # notice matters here.
            if isinstance(chunk, dict) and "error" in chunk:
                prefix_error = chunk
        if prefix_error is not None:
            yield prefix_error
            return
        if grpc_context and grpc_context.cancelled():
            # The prefix request was already aborted; don't fan out n more.
            return
        logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")

        # Phase 2: Generate n parallel requests
        logger.debug(f"Phase 2: Generating {n} parallel requests")
        request_ids = [f"{base_request_id}-{i}" for i in range(n)]
        generators = []
        for gen_request_id in request_ids:
            gen_obj = copy.copy(obj)
            gen_obj.sampling_params = copy.copy(obj.sampling_params)
            gen_obj.sampling_params.n = 1  # Each request generates 1 response
            generators.append(
                self._handle_single_request(gen_obj, gen_request_id, grpc_context)
            )

        if not getattr(obj, "stream", False):
            logger.debug(f"Non-streaming mode: collecting {n} responses")
            yield await self._collect_all(generators)
            return

        logger.debug(f"Streaming mode: multiplexing {n} streams")
        async for response in self._multiplex_streams(generators, request_ids):
            yield response

    @staticmethod
    async def _collect_all(generators) -> List[Dict]:
        """Drain the async generators concurrently; return items in generator order.

        The generators start lazily (nothing is sent to the scheduler until the
        first ``__anext__``), so draining them one after another would run the
        sub-requests serially. Each is drained in its own task instead.
        """

        async def drain(gen):
            return [item async for item in gen]

        tasks = [asyncio.create_task(drain(gen)) for gen in generators]
        try:
            per_generator = await asyncio.gather(*tasks)
        finally:
            # If one sub-request failed (or we were cancelled), don't leave
            # its siblings running.
            for task in tasks:
                if not task.done():
                    task.cancel()
        return [item for items in per_generator for item in items]

    @staticmethod
    async def _multiplex_streams(generators, request_ids):
        """Yield chunks from several async generators as soon as each is ready.

        Dict chunks whose ``request_id`` (falling back to ``meta_info["id"]``)
        is in ``request_ids`` get an ``"index"`` key with that id's position.
        Outstanding ``__anext__`` tasks are cancelled if iteration stops early.
        """
        rid_to_index = {rid: idx for idx, rid in enumerate(request_ids)}
        pending = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        try:
            while pending:
                done, _ = await asyncio.wait(
                    pending.keys(), return_when=asyncio.FIRST_COMPLETED
                )
                for task in done:
                    gen = pending.pop(task)
                    try:
                        response = task.result()
                    except StopAsyncIteration:
                        continue  # This generator is finished

                    # Add index for client-side ordering
                    if isinstance(response, dict):
                        rid = response.get("request_id")
                        meta_info = response.get("meta_info")
                        if rid is None and isinstance(meta_info, dict):
                            rid = meta_info.get("id")
                        if rid in rid_to_index:
                            response["index"] = rid_to_index[rid]
                    yield response

                    pending[asyncio.create_task(gen.__anext__())] = gen
        finally:
            for task in pending:
                task.cancel()

    async def _handle_single_request(
        self,
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ):
        """Handle a single request - core implementation without n>1 logic.

        Registers a GrpcReqState, sends ``obj`` to the scheduler and yields its
        outputs: every chunk when ``obj.stream`` is set, otherwise one final
        chunk whose ``token_ids`` hold all accumulated tokens. An error dict
        (abort or shutdown notice) is yielded as-is and ends the stream. If the
        client cancels, the scheduler is told to abort and nothing more is
        yielded. The request state is always removed on exit.

        Raises:
            Whatever ``_send_to_scheduler`` raises; it is deliberately not
            caught so grpc_server.py can report it.
        """
        if request_id is None:
            request_id = f"grpc-{uuid.uuid4().hex}"
        obj.rid = request_id

        # TODO: support log_request
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=grpc_context,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )
        session_params = getattr(obj, "session_params", None)
        if session_params:
            state.session_id = session_params.session_id
            state.is_session_request = True

        self.rid_to_state[request_id] = state
        self.record_request_for_crash_dump(obj)

        try:
            # Send to scheduler - let exceptions bubble up to grpc_server.py
            await self._send_to_scheduler(obj)
            is_stream = getattr(obj, "stream", False)

            while True:
                # Client cancelled - notify scheduler and exit
                if grpc_context and grpc_context.cancelled():
                    await self.abort_request(request_id)
                    return

                try:
                    response = await asyncio.wait_for(
                        state.out_queue.get(), timeout=self._RESPONSE_POLL_INTERVAL_S
                    )
                except asyncio.TimeoutError:
                    # Only a wake-up to re-check cancellation; keep waiting.
                    continue

                # Abort / shutdown notices are terminal for this request.
                if isinstance(response, dict) and "error" in response:
                    yield response
                    return

                if is_stream:
                    yield response

                if isinstance(response, dict) and response.get("finished", False):
                    if not is_stream:
                        # Non-streaming: one final response with all tokens.
                        final_response = response.copy()
                        final_response["token_ids"] = state.output_ids
                        yield final_response
                    return
        finally:
            # Always clean up request state when exiting
            self._cleanup_request_state(request_id)

    def _cleanup_request_state(self, request_id: str):
        """Clean up local request state (does not notify scheduler)."""
        self.rid_to_state.pop(request_id, None)

    # ------------------------------------------------------------------ #
    # Embedding
    # ------------------------------------------------------------------ #

    async def embedding_request(
        self,
        obj: TokenizedEmbeddingReqInput,
        request_id: Optional[str] = None,
    ) -> asyncio.Future:
        """
        Submit an embedding request to the scheduler.

        Returns a future that will contain the embedding result: the first
        item put on the request's queue once its event is set (normally the
        embedding dict, or a shutdown notice). If sending fails, the returned
        future already holds that exception.
        """
        if request_id is None:
            request_id = f"grpc-embed-{uuid.uuid4().hex}"
        obj.rid = request_id

        state = GrpcReqState(
            request_id=request_id,
            grpc_context=None,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )
        self.rid_to_state[request_id] = state

        future = asyncio.get_running_loop().create_future()

        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            self._cleanup_request_state(request_id)
            future.set_exception(e)
            return future

        async def wait_for_result():
            try:
                await state.event.wait()
                result = await state.out_queue.get()
            except asyncio.CancelledError:
                if not future.done():
                    future.cancel()
                raise
            except Exception as e:
                if not future.done():
                    future.set_exception(e)
            else:
                # The caller may have cancelled the future in the meantime.
                if not future.done():
                    future.set_result(result)
            finally:
                self._cleanup_request_state(request_id)

        self._spawn_background(wait_for_result())
        return future

    # ------------------------------------------------------------------ #
    # Control
    # ------------------------------------------------------------------ #

    async def abort_request(self, request_id: str) -> bool:
        """Abort a running request.

        Sends an AbortReq to the scheduler, marks the local state finished and
        queues ``{"error": "Request aborted", "abort": True}`` for its consumer.

        Returns:
            True if the abort was sent; False if the id is unknown or sending
            failed (the state is left untouched in that case).
        """
        if request_id not in self.rid_to_state:
            return False

        abort_req = AbortReq(rid=request_id)
        try:
            await self._send_to_scheduler(abort_req)
        except Exception as e:
            logger.error(f"Failed to send abort request: {e}")
            return False

        state = self.rid_to_state.get(request_id)
        if state:
            state.finished = True
            state.stream_finished = True
            state.event.set()
            await state.out_queue.put({"error": "Request aborted", "abort": True})

        return True

    async def pause_generation(self):
        """Pause generation processing.

        While paused, ``handle_loop`` blocks after receiving its next object
        and does not dispatch it until ``resume_generation`` is called.
        """
        async with self.is_pause_cond:
            self.is_pause = True
            logger.info("Generation paused")

    async def resume_generation(self):
        """Resume generation processing and wake any paused ``handle_loop``."""
        async with self.is_pause_cond:
            self.is_pause = False
            self.is_pause_cond.notify_all()
            logger.info("Generation resumed")

    # ------------------------------------------------------------------ #
    # Scheduler output handling
    # ------------------------------------------------------------------ #

    async def handle_loop(self):
        """
        Main event loop - processes outputs from scheduler.
        Mimics TokenizerManager's handle_loop.

        Runs until ``gracefully_exit`` is set or a non-shutdown ZMQ error
        occurs; other exceptions are logged and the loop continues.
        """
        while not self.gracefully_exit:
            try:
                recv_obj = await self.recv_from_scheduler.recv_pyobj()
                self.last_receive_tstamp = time.time()

                # Check for pause
                async with self.is_pause_cond:
                    while self.is_pause:
                        await self.is_pause_cond.wait()

                if isinstance(recv_obj, BatchTokenIDOut):
                    await self._handle_batch_output(recv_obj)
                elif isinstance(recv_obj, BatchEmbeddingOut):
                    await self._handle_embedding_output(recv_obj)
                elif isinstance(recv_obj, HealthCheckOutput):
                    await self._handle_health_check_output(recv_obj)
                else:
                    logger.warning(f"Unknown output type: {type(recv_obj)}")

            except zmq.error.Again:
                # Timeout, check if we should exit
                if self.gracefully_exit:
                    break
                continue
            except zmq.error.ZMQError as e:
                # Socket closed or other ZMQ error - exit cleanly if shutting down
                if self.gracefully_exit:
                    logger.debug(f"ZMQ recv interrupted during shutdown: {e}")
                    break
                logger.error(
                    f"ZMQ error in handle loop: {e}\n{get_exception_traceback()}"
                )
                break
            except Exception as e:
                logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
                if self.gracefully_exit:
                    break

    def _convert_logprob_style(
        self,
        state: GrpcReqState,
        batch_out: BatchTokenIDOut,
        batch_index: int,
    ):
        """
        Convert and accumulate logprobs from batch output to state.
        Follows the same logic as tokenizer_manager.convert_logprob_style.

        Nothing is accumulated (input or output) when the batch carries no
        ``input_token_logprobs_val`` at all.
        """
        # Early exit if no input logprobs at all
        if batch_out.input_token_logprobs_val is None:
            return

        # Accumulate input token logprobs (only if list is non-empty)
        if len(batch_out.input_token_logprobs_val) > 0:
            state.input_token_logprobs_val.extend(
                batch_out.input_token_logprobs_val[batch_index]
            )
            state.input_token_logprobs_idx.extend(
                batch_out.input_token_logprobs_idx[batch_index]
            )

        # Always accumulate output token logprobs
        state.output_token_logprobs_val.extend(
            batch_out.output_token_logprobs_val[batch_index]
        )
        state.output_token_logprobs_idx.extend(
            batch_out.output_token_logprobs_idx[batch_index]
        )

        # Handle top logprobs if requested
        if state.obj.top_logprobs_num > 0:
            # Accumulate input top logprobs (only if list is non-empty)
            if len(batch_out.input_top_logprobs_val) > 0:
                state.input_top_logprobs_val.extend(
                    batch_out.input_top_logprobs_val[batch_index]
                )
                state.input_top_logprobs_idx.extend(
                    batch_out.input_top_logprobs_idx[batch_index]
                )

            # Always accumulate output top logprobs
            state.output_top_logprobs_val.extend(
                batch_out.output_top_logprobs_val[batch_index]
            )
            state.output_top_logprobs_idx.extend(
                batch_out.output_top_logprobs_idx[batch_index]
            )

    @staticmethod
    def _batch_field(batch_out, attr_name: str, index: int):
        """Return ``batch_out.<attr_name>[index]``, or [] if absent/short."""
        source_list = getattr(batch_out, attr_name, None)
        return source_list[index] if source_list and index < len(source_list) else []

    @staticmethod
    def _input_logprobs_payload(state: GrpcReqState) -> Dict[str, Any]:
        """Accumulated input logprobs of ``state`` in response-dict form."""
        return {
            "token_logprobs_val": state.input_token_logprobs_val,
            "token_logprobs_idx": state.input_token_logprobs_idx,
            "top_logprobs_val": state.input_top_logprobs_val,
            "top_logprobs_idx": state.input_top_logprobs_idx,
        }

    def _attach_logprobs(
        self,
        output_data: Dict[str, Any],
        state: GrpcReqState,
        batch_out: BatchTokenIDOut,
        index: int,
    ):
        """Add ``input_logprobs`` / ``output_logprobs`` to ``output_data`` when due.

        Streaming: input logprobs once (first chunk that has any), output
        logprobs per chunk. Non-streaming: both, cumulative, in the finished
        chunk only.
        """
        is_stream = state.obj.stream
        finished = output_data["finished"]

        if state.obj.logprob_start_len >= 0 and state.input_token_logprobs_val:
            if is_stream and not state.input_logprobs_sent:
                output_data["input_logprobs"] = self._input_logprobs_payload(state)
                state.input_logprobs_sent = True
            elif not is_stream and finished:
                output_data["input_logprobs"] = self._input_logprobs_payload(state)

        if batch_out.output_token_logprobs_val and index < len(
            batch_out.output_token_logprobs_val
        ):
            if is_stream:
                # NOTE: incremental (only this chunk's tokens); this differs
                # from TokenizerManager, which always accumulates.
                output_data["output_logprobs"] = {
                    "token_logprobs_val": batch_out.output_token_logprobs_val[index],
                    "token_logprobs_idx": self._batch_field(
                        batch_out, "output_token_logprobs_idx", index
                    ),
                    "top_logprobs_val": self._batch_field(
                        batch_out, "output_top_logprobs_val", index
                    ),
                    "top_logprobs_idx": self._batch_field(
                        batch_out, "output_top_logprobs_idx", index
                    ),
                }
            elif finished:
                output_data["output_logprobs"] = {
                    "token_logprobs_val": state.output_token_logprobs_val,
                    "token_logprobs_idx": state.output_token_logprobs_idx,
                    "top_logprobs_val": state.output_top_logprobs_val,
                    "top_logprobs_idx": state.output_top_logprobs_idx,
                }

    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
        """Handle batch generation output from scheduler.

        For every rid still tracked in ``rid_to_state`` this builds an output
        dict (token ids, finished flag, meta_info and, when requested,
        logprobs), accumulates tokens on the state and enqueues the dict.
        Finished requests are marked done and scheduled for delayed cleanup.
        """
        for i, rid in enumerate(batch_out.rids):
            state = self.rid_to_state.get(rid)
            if state is None:
                continue

            now = time.time()
            if state.first_token_time == 0.0:
                state.first_token_time = now
            state.last_time = now

            finished_reason = batch_out.finished_reasons[i]
            output_data = {
                "request_id": rid,
                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                "finished": finished_reason is not None,
                "meta_info": {
                    "prompt_tokens": (
                        batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                    ),
                    "completion_tokens": (
                        batch_out.completion_tokens[i]
                        if batch_out.completion_tokens
                        else 0
                    ),
                    "cached_tokens": (
                        batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
                    ),
                    "finish_reason": (
                        str(finished_reason) if finished_reason else None
                    ),
                },
            }

            # Accumulate logprobs (following tokenizer_manager pattern)
            if state.obj.return_logprob:
                self._convert_logprob_style(state, batch_out, i)
                self._attach_logprobs(output_data, state, batch_out, i)

            if output_data["token_ids"]:
                state.output_ids.extend(output_data["token_ids"])

            await state.out_queue.put(output_data)

            if output_data["finished"]:
                state.finished = True
                state.finished_time = now
                state.stream_finished = True
                state.event.set()
                # rid/state are passed explicitly: a closure defined in this
                # loop would late-bind ``rid`` to the batch's last entry.
                self._spawn_background(self._delayed_state_cleanup(rid, state))

    async def _delayed_state_cleanup(self, rid: str, state: GrpcReqState):
        """Drop ``state`` from ``rid_to_state`` after a grace period.

        The entry is removed only if it still belongs to this request, so a
        newer request that reused ``rid`` in the meantime is left untouched.
        """
        await asyncio.sleep(self._FINISHED_STATE_TTL_S)
        if self.rid_to_state.get(rid) is state:
            del self.rid_to_state[rid]

    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
        """Handle batch embedding output from scheduler.

        Each tracked request gets one result dict on its queue, is marked
        finished and has its event set.
        """
        # NOTE(review): this originally read ``batch_out.finish_reason``; some
        # io_struct versions appear to name the field ``finished_reasons``.
        # Accept either -- confirm against BatchEmbeddingOut.
        finish_reasons = getattr(batch_out, "finish_reason", None)
        if finish_reasons is None:
            finish_reasons = getattr(batch_out, "finished_reasons", None)

        for i, rid in enumerate(batch_out.rids):
            state = self.rid_to_state.get(rid)
            if state is None:
                continue

            result = {
                "request_id": rid,
                "embedding": batch_out.embeddings[i],
                "prompt_tokens": (
                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                ),
                "finish_reason": finish_reasons[i] if finish_reasons else None,
            }
            await state.out_queue.put(result)

            state.finished = True
            state.finished_time = time.time()
            state.event.set()

    async def _handle_health_check_output(self, health_out: HealthCheckOutput):
        """Handle health check output from scheduler.

        Any received output is reported as ``healthy: True``; unknown rids are
        logged and ignored.
        """
        rid = health_out.rid
        if rid not in self.rid_to_state:
            logger.warning(f"Health check output for unknown request: {rid}")
            return

        state = self.rid_to_state[rid]
        result = {
            "request_id": rid,
            "healthy": True,  # If we got a response, scheduler is healthy
            "output_text": (
                health_out.output_str if hasattr(health_out, "output_str") else ""
            ),
            "finish_reason": (
                health_out.finish_reason
                if hasattr(health_out, "finish_reason")
                else "stop"
            ),
        }
        await state.out_queue.put(result)

        state.finished = True
        state.finished_time = time.time()
        state.event.set()

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    async def _send_to_scheduler(self, obj):
        """Send an object to the scheduler via ZMQ.

        The socket comes from a ``zmq.asyncio`` context, whose ``send_pyobj``
        returns an awaitable; awaiting it makes send failures surface here
        (logged, then re-raised) instead of being silently dropped.
        """
        try:
            pending = self.send_to_scheduler.send_pyobj(obj)
            if hasattr(pending, "__await__"):
                await pending
        except Exception as e:
            logger.error(f"Failed to send to scheduler: {e}")
            raise

    def _spawn_background(self, coro) -> asyncio.Task:
        """Start ``coro`` as a task and keep a strong reference until it ends."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def record_request_for_crash_dump(self, obj):
        """Record request for potential crash dump.

        Keeps a rolling window of the most recent
        ``_CRASH_DUMP_MAX_REQUESTS`` entries (oldest dropped first).
        """
        if len(self.crash_dump_request_list) >= self._CRASH_DUMP_MAX_REQUESTS:
            del self.crash_dump_request_list[0]
        self.crash_dump_request_list.append(
            {
                "time": time.time(),
                "request_id": getattr(obj, "rid", "unknown"),
                "type": type(obj).__name__,
            }
        )

    def dump_requests_before_crash(self):
        """Log the requests recorded by ``record_request_for_crash_dump``.

        Invoked by ``print_exception_wrapper`` when a managed loop crashes.
        Runs at most once per manager (guarded by ``crash_dump_performed``)
        and only logs; no dump file is written.
        """
        if self.crash_dump_performed:
            return
        self.crash_dump_performed = True
        logger.error(
            f"Dumping {len(self.crash_dump_request_list)} most recent requests "
            f"before crash: {self.crash_dump_request_list}"
        )

    # ------------------------------------------------------------------ #
    # Lifecycle
    # ------------------------------------------------------------------ #

    async def shutdown(self):
        """Gracefully shutdown the request manager.

        Cancels the managed loops, sends a shutdown notice to every unfinished
        request, stops the bootstrap server if one is running, then closes
        both ZMQ sockets and terminates the context.
        """
        logger.info("Shutting down GrpcRequestManager")
        self.gracefully_exit = True

        # Cancel all asyncio tasks FIRST - this will interrupt blocked recv() calls
        for task in list(self.asyncio_tasks):
            if not task.done():
                task.cancel()
        # Wait until the cancellations have been processed.
        if self.asyncio_tasks:
            await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

        # Cancel all pending requests
        for state in list(self.rid_to_state.values()):
            if not state.finished:
                await state.out_queue.put(
                    {"error": "Server shutting down", "shutdown": True}
                )
                state.finished = True
                state.event.set()

        # Shutdown bootstrap server if running
        if self.bootstrap_server:
            logger.info("Shutting down bootstrap server")
            try:
                if hasattr(self.bootstrap_server, "shutdown"):
                    if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
                        await self.bootstrap_server.shutdown()
                    else:
                        self.bootstrap_server.shutdown()
            except Exception as e:
                logger.warning(f"Error shutting down bootstrap server: {e}")

        # Close ZMQ sockets
        self.recv_from_scheduler.close()
        self.send_to_scheduler.close()

        # Terminate the ZMQ context - this is critical for asyncio loop to exit cleanly
        self.context.term()

        logger.info("GrpcRequestManager shutdown complete")

    def get_server_info(self) -> Dict[str, Any]:
        """Get server information for health checks."""
        return {
            "active_requests": len(self.rid_to_state),
            "paused": self.is_pause,
            "last_receive_time": self.last_receive_tstamp,
        }

    def auto_create_handle_loop(self):
        """Automatically create and start the handle_loop task, matching TokenizerManager pattern.

        Idempotent (guarded by ``no_create_loop``). Also installs SIGTERM and
        SIGQUIT handlers when running on the main thread, and starts
        ``sigterm_watchdog``.
        """
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        self.event_loop = loop

        # We cannot add signal handler when the grpc manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = GrpcSignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the grpc request manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "grpc request manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        """Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern.

        Polls ``gracefully_exit`` once per second and returns once it is set;
        it performs no further shutdown work itself.
        """
        while not self.gracefully_exit:
            await asyncio.sleep(1.0)
async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print exception.
    We do another wrapper to handle the exception.

    Awaits ``func()``. On any ``Exception`` it logs the traceback, then asks
    the owning GrpcRequestManager (when ``func`` is a bound method of one) to
    dump its recorded requests. Finally it kills the whole process tree and
    exits with status 1. The dump is best-effort: a missing
    ``dump_requests_before_crash`` method, or a failure inside it, must not
    prevent the process from being torn down.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {traceback}")
        owner = getattr(func, "__self__", None)
        if isinstance(owner, GrpcRequestManager):
            dump = getattr(owner, "dump_requests_before_crash", None)
            if callable(dump):
                try:
                    dump()
                except Exception:
                    logger.error(
                        f"Failed to dump requests before crash: "
                        f"{get_exception_traceback()}"
                    )
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)