856 lines
32 KiB
Python
856 lines
32 KiB
Python
"""
gRPC Request Manager - Orchestrates request lifecycle without tokenization.

Mimics TokenizerManager's state management and ZMQ communication patterns.
"""
|
|
|
|
import asyncio
|
|
import copy
|
|
import dataclasses
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
|
|
|
import grpc
|
|
import zmq
|
|
import zmq.asyncio
|
|
|
|
from sglang.srt.managers.disagg_service import start_disagg_service
|
|
from sglang.srt.managers.io_struct import (
|
|
AbortReq,
|
|
BatchEmbeddingOut,
|
|
BatchTokenIDOut,
|
|
HealthCheckOutput,
|
|
TokenizedEmbeddingReqInput,
|
|
TokenizedGenerateReqInput,
|
|
)
|
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
|
from sglang.utils import get_exception_traceback
|
|
|
|
# Module-level logger used by the signal handler, the request manager and
# print_exception_wrapper in this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GrpcSignalHandler:
    """Minimal signal hooks for the gRPC server process.

    Real crash handling lives in the scheduler process. These hooks only do two
    things: flip the manager's shutdown flag, or tear down this process tree.
    """

    def __init__(self, grpc_manager):
        # Manager whose ``gracefully_exit`` flag is toggled on SIGTERM.
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """On SIGTERM, ask the gRPC manager to wind down gracefully."""
        message = f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        logger.warning(message)
        manager = self.grpc_manager
        manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """On SIGQUIT from a failed scheduler, kill this process and its children."""
        failure_note = (
            "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
        )
        dump_note = (
            "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
        )
        logger.error(failure_note)
        logger.info(dump_note)

        # Crash dumps are the scheduler's job; this side just exits cleanly.
        own_pid = os.getpid()
        kill_process_tree(own_pid, include_parent=True)
|
|
|
|
|
|
@dataclasses.dataclass
class GrpcReqState:
    """State tracking for a gRPC request.

    GrpcRequestManager keeps one instance per in-flight request in its
    ``rid_to_state`` dict. The output handlers push responses into
    ``out_queue`` and the request's coroutine consumes them.
    """

    # Request identification
    # Request id; the manager also writes it to ``obj.rid``.
    request_id: str
    # Servicer context of the calling RPC (None for embedding requests).
    grpc_context: Optional[grpc.aio.ServicerContext]

    # Communication
    # Responses destined for the request's coroutine.
    out_queue: asyncio.Queue
    # Set once a final output, abort or shutdown notice has been delivered.
    finished: bool
    # Set together with ``finished``; embedding_request waits on it.
    event: asyncio.Event
    # The tokenized request object that was sent to the scheduler.
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # Metrics (same as TokenizerManager's ReqState)
    # ``time.time()`` when the state was created.
    created_time: float
    # Wall-clock time at which the request finished (0.0 while running).
    finished_time: float = 0.0
    # Wall-clock time of the first batch output seen for this request.
    first_token_time: float = 0.0
    # Wall-clock time of the most recent batch output.
    last_time: float = 0.0
    # Mirrors TokenizerManager's ReqState; NOTE(review): not read in this module.
    last_completion_tokens: int = 1

    # Streaming state
    # Set when the request finished or was aborted.
    stream_finished: bool = False
    input_logprobs_sent: bool = False  # Track if input logprobs were sent in streaming

    # Token accumulation (for non-streaming)
    # Output token ids accumulated across all batch outputs for this request.
    output_ids: List[int] = dataclasses.field(default_factory=list)
    # Logprob accumulators, extended chunk by chunk by _convert_logprob_style.
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)

    # Session state
    # Populated only when the request carries truthy ``session_params``.
    session_id: Optional[str] = None
    is_session_request: bool = False
|
|
|
|
|
|
class GrpcRequestManager:
    """
    Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
    behaviors without tokenization.

    Requests go to the scheduler over a ZMQ PUSH socket. Outputs come back on
    a ZMQ PULL socket, and ``handle_loop`` routes them into the per-request
    ``GrpcReqState.out_queue`` registered in ``rid_to_state``.
    """

    # How often (seconds) a request waiting for output wakes up to check
    # whether its gRPC client cancelled. Reaching this interval is NOT an
    # error: queueing and long prefills routinely produce no output for longer.
    _CANCEL_POLL_INTERVAL_S = 4.0

    # Grace period (seconds) before the output handler drops a finished
    # request's state, as a safety net in case its consumer never does.
    _FINISHED_STATE_TTL_S = 5.0

    # Number of most recent requests retained for crash dumps.
    _CRASH_DUMP_MAX_REQUESTS = 100

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialize the gRPC request manager (ZMQ sockets, state, bootstrap server)."""
        self.server_args = server_args
        self.port_args = port_args

        # ZMQ Communication Setup (same pattern as TokenizerManager)
        self.context = zmq.asyncio.Context(2)

        # Socket for receiving outputs from scheduler
        self.recv_from_scheduler = get_zmq_socket(
            self.context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
        )

        # Socket for sending requests to scheduler
        self.send_to_scheduler = get_zmq_socket(
            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
        )

        # State Management (from TokenizerManager)
        self.rid_to_state: Dict[str, GrpcReqState] = {}
        self.asyncio_tasks: set = set()
        self.gracefully_exit = False
        self.no_create_loop = False
        self.event_loop = None

        # Strong references to fire-and-forget tasks. Without them, the event
        # loop only holds weak references and the tasks may be garbage-collected.
        self._background_tasks: set = set()

        # Pause/Resume Control
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

        # Metrics
        self.last_receive_tstamp = time.time()

        # Crash dump for debugging
        self.crash_dump_request_list = []
        self.crash_dump_performed = False

        # Bootstrap server for disaggregation mode
        self.bootstrap_server = start_disagg_service(server_args)

        logger.info(
            f"GrpcRequestManager initialized with ZMQ IPC: "
            f"recv={port_args.detokenizer_ipc_name}, "
            f"send={port_args.scheduler_input_ipc_name}"
        )
        if self.bootstrap_server:
            logger.info(
                f"Bootstrap server started for disaggregation mode: "
                f"{server_args.disaggregation_mode}"
            )

    async def generate_request(
        self,
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
        """
        Submit a generation request to the scheduler with n>1 parallel sampling support.

        This method implements the same two-phase approach as tokenizer_manager.py:
        1. Phase 1: Send prefix caching request (max_new_tokens=0)
        2. Phase 2: Send n generation requests that reuse the cached prefix

        Args:
            obj: Tokenized request. ``obj.sampling_params.n`` selects the path.
            request_id: Request id, generated when None. For n>1 it is the base
                for the sub-request ids ``{id}-prefix`` and ``{id}-{i}``.
            grpc_context: Servicer context, polled for client cancellation.

        Yields:
            n<=1: the responses of the single request.
            n>1 streaming: each sample's responses as they arrive. Dict
                responses carry ``index`` = the sample number.
            n>1 non-streaming: one list with every sample's responses, in
                sample order.
        """
        n = getattr(obj.sampling_params, "n", 1)

        if n <= 1:
            async for response in self._handle_single_request(
                obj, request_id, grpc_context
            ):
                yield response
            return

        # N>1 handling - two-phase approach
        logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")

        # Generate base request ID if not provided
        if request_id is None:
            base_request_id = f"grpc-{uuid.uuid4().hex}"
        else:
            base_request_id = request_id

        # Phase 1: Cache the common prefix
        logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
        prefix_obj = copy.copy(obj)
        prefix_obj.sampling_params = copy.copy(obj.sampling_params)
        prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
        prefix_obj.sampling_params.n = 1  # Don't replicate prefix request

        # Send prefix caching request and consume its (discarded) response
        async for _ in self._handle_single_request(
            prefix_obj, f"{base_request_id}-prefix", grpc_context
        ):
            pass

        logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")

        # Phase 2: Generate n parallel requests. Async generators only send
        # their request once first iterated, so they must be driven
        # concurrently below for the samples to really run in parallel.
        logger.debug(f"Phase 2: Generating {n} parallel requests")
        generators = []
        for i in range(n):
            gen_obj = copy.copy(obj)
            gen_obj.sampling_params = copy.copy(obj.sampling_params)
            gen_obj.sampling_params.n = 1  # Each request generates 1 response
            generators.append(
                self._handle_single_request(
                    gen_obj, f"{base_request_id}-{i}", grpc_context
                )
            )

        is_stream = getattr(obj, "stream", False)

        if not is_stream:
            # Non-streaming: collect all responses and return as batch
            logger.debug(f"Non-streaming mode: collecting {n} responses")
            yield await self._collect_concurrently(generators)
            return

        # Streaming mode: multiplex responses, tagging each with its sample
        # index. The index comes from the originating generator because the
        # responses themselves carry no id inside meta_info.
        logger.debug(f"Streaming mode: multiplexing {n} streams")
        pending: Dict[asyncio.Task, int] = {
            asyncio.create_task(gen.__anext__()): index
            for index, gen in enumerate(generators)
        }
        try:
            while pending:
                done, _ = await asyncio.wait(
                    set(pending), return_when=asyncio.FIRST_COMPLETED
                )
                for task in done:
                    index = pending.pop(task)
                    try:
                        response = task.result()
                    except StopAsyncIteration:
                        continue  # This sample's stream is finished

                    if isinstance(response, dict):
                        response["index"] = index
                    yield response

                    # Request the next chunk from the same sample
                    pending[asyncio.create_task(generators[index].__anext__())] = (
                        index
                    )
        finally:
            # If the consumer stopped early or something failed, don't leave
            # per-sample tasks (and their requests) running in the background.
            for task in pending:
                task.cancel()

    @staticmethod
    async def _collect_concurrently(generators: List[AsyncGenerator]) -> List[Dict]:
        """Drain all generators concurrently; return their responses in generator order."""

        async def drain(gen):
            return [item async for item in gen]

        tasks = [asyncio.create_task(drain(gen)) for gen in generators]
        try:
            per_generator = await asyncio.gather(*tasks)
        except BaseException:
            # One sample failed or we were cancelled: stop the siblings too.
            for task in tasks:
                task.cancel()
            raise
        return [item for items in per_generator for item in items]

    @staticmethod
    def _is_abort_or_shutdown(response) -> bool:
        """Return True for the sentinel dicts enqueued by abort_request() / shutdown()."""
        return isinstance(response, dict) and bool(
            response.get("abort") or response.get("shutdown")
        )

    async def _handle_single_request(
        self,
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ):
        """Handle a single request - core implementation without n>1 logic.

        Registers state, sends ``obj`` to the scheduler and yields outputs:
        every response when streaming, otherwise only the final response
        (with ``token_ids`` replaced by all accumulated tokens). An abort or
        shutdown notice is yielded in both modes and ends the stream. The
        state is always removed from ``rid_to_state`` on exit.
        """
        # Generate request ID if not provided
        if request_id is None:
            request_id = f"grpc-{uuid.uuid4().hex}"

        obj.rid = request_id

        # Create and register request state
        # TODO: support log_request
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=grpc_context,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Track session if needed
        if hasattr(obj, "session_params") and obj.session_params:
            state.session_id = obj.session_params.session_id
            state.is_session_request = True

        self.rid_to_state[request_id] = state
        self.record_request_for_crash_dump(obj)

        try:
            # Send to scheduler - let exceptions bubble up to grpc_server.py
            await self._send_to_scheduler(obj)

            is_stream = getattr(obj, "stream", False)

            while True:
                # Client cancelled - notify scheduler and exit
                if grpc_context and grpc_context.cancelled():
                    await self.abort_request(request_id)
                    return

                try:
                    response = await asyncio.wait_for(
                        state.out_queue.get(), timeout=self._CANCEL_POLL_INTERVAL_S
                    )
                except asyncio.TimeoutError:
                    # No output yet (queueing, long prefill, sparse non-stream
                    # chunks). Loop back to re-check cancellation; do NOT abort.
                    continue

                if is_stream:
                    yield response

                if self._is_abort_or_shutdown(response):
                    # Abort/shutdown notices never carry "finished", so stop
                    # here instead of waiting for output that will not come.
                    if not is_stream:
                        yield response
                    return

                # Non-streaming: yield final response with accumulated tokens from state
                if isinstance(response, dict) and response.get("finished", False):
                    if not is_stream:
                        final_response = response.copy()
                        final_response["token_ids"] = state.output_ids
                        yield final_response
                    break

        finally:
            # Always clean up our own request state when exiting
            self._cleanup_request_state(request_id, state)

    def _cleanup_request_state(
        self, request_id: str, state: Optional[GrpcReqState] = None
    ):
        """Clean up local request state (does not notify scheduler).

        If ``state`` is given, the entry is removed only when it is still that
        exact object. This avoids deleting a newer request that reused the id.
        """
        current = self.rid_to_state.get(request_id)
        if current is None:
            return
        if state is not None and current is not state:
            return
        del self.rid_to_state[request_id]

    def _spawn_background_task(self, coro) -> asyncio.Task:
        """Start ``coro`` as a task, keeping a strong reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def _delayed_state_cleanup(self, request_id: str, state: GrpcReqState):
        """After a grace period, drop ``request_id``'s state if it still belongs to ``state``."""
        await asyncio.sleep(self._FINISHED_STATE_TTL_S)
        self._cleanup_request_state(request_id, state)

    async def embedding_request(
        self,
        obj: TokenizedEmbeddingReqInput,
        request_id: Optional[str] = None,
    ) -> asyncio.Future:
        """
        Submit an embedding request to the scheduler.

        Returns a future that will contain the embedding result. If sending
        fails, the future instead holds that exception.
        """
        # Generate request ID if not provided
        if request_id is None:
            request_id = f"grpc-embed-{uuid.uuid4().hex}"

        obj.rid = request_id

        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=None,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Register state (and record it for crash dumps, like generate requests)
        self.rid_to_state[request_id] = state
        self.record_request_for_crash_dump(obj)

        # Create future for result
        future = asyncio.get_running_loop().create_future()

        # Send to scheduler
        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            self._cleanup_request_state(request_id, state)
            future.set_exception(e)
            return future

        # Wait for result in background
        async def wait_for_result():
            try:
                # Wait for completion, then take the result from the queue
                await state.event.wait()
                result = await state.out_queue.get()
                # The caller may have cancelled the future in the meantime.
                if not future.done():
                    future.set_result(result)
            except asyncio.CancelledError:
                if not future.done():
                    future.cancel()
                raise
            except Exception as e:
                if not future.done():
                    future.set_exception(e)
            finally:
                self._cleanup_request_state(request_id, state)

        self._spawn_background_task(wait_for_result())
        return future

    async def abort_request(self, request_id: str) -> bool:
        """Abort a running request.

        Sends an AbortReq to the scheduler, marks the local state finished and
        enqueues an ``{"error": ..., "abort": True}`` notice for the consumer.
        Returns False if the request is unknown or the abort could not be sent.
        """
        if request_id not in self.rid_to_state:
            return False

        # Send abort to scheduler
        abort_req = AbortReq(rid=request_id)
        try:
            await self._send_to_scheduler(abort_req)
        except Exception as e:
            logger.error(f"Failed to send abort request: {e}")
            return False

        # Mark as finished
        state = self.rid_to_state.get(request_id)
        if state:
            state.finished = True
            state.stream_finished = True
            state.event.set()

            # Send abort notification to output queue
            await state.out_queue.put({"error": "Request aborted", "abort": True})

        return True

    async def pause_generation(self):
        """Pause generation processing (handle_loop blocks before dispatching)."""
        async with self.is_pause_cond:
            self.is_pause = True
            logger.info("Generation paused")

    async def resume_generation(self):
        """Resume generation processing and wake handle_loop."""
        async with self.is_pause_cond:
            self.is_pause = False
            self.is_pause_cond.notify_all()
            logger.info("Generation resumed")

    async def handle_loop(self):
        """
        Main event loop - processes outputs from scheduler.
        Mimics TokenizerManager's handle_loop.
        """
        while not self.gracefully_exit:
            try:
                # Receive from scheduler
                recv_obj = await self.recv_from_scheduler.recv_pyobj()
                self.last_receive_tstamp = time.time()

                # Check for pause
                async with self.is_pause_cond:
                    while self.is_pause:
                        await self.is_pause_cond.wait()

                # Handle different output types
                if isinstance(recv_obj, BatchTokenIDOut):
                    await self._handle_batch_output(recv_obj)
                elif isinstance(recv_obj, BatchEmbeddingOut):
                    await self._handle_embedding_output(recv_obj)
                elif isinstance(recv_obj, HealthCheckOutput):
                    await self._handle_health_check_output(recv_obj)
                else:
                    logger.warning(f"Unknown output type: {type(recv_obj)}")

            except zmq.error.Again:
                # Timeout, check if we should exit
                if self.gracefully_exit:
                    break
                continue
            except zmq.error.ZMQError as e:
                # Socket closed or other ZMQ error - exit cleanly if shutting down
                if self.gracefully_exit:
                    logger.debug(f"ZMQ recv interrupted during shutdown: {e}")
                    break
                logger.error(
                    f"ZMQ error in handle loop: {e}\n{get_exception_traceback()}"
                )
                break
            except Exception as e:
                logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
                if self.gracefully_exit:
                    break

    def _convert_logprob_style(
        self,
        state: GrpcReqState,
        batch_out: BatchTokenIDOut,
        batch_index: int,
    ):
        """
        Convert and accumulate logprobs from batch output to state.
        Follows the same logic as tokenizer_manager.convert_logprob_style.
        """
        # Early exit if no input logprobs at all
        if batch_out.input_token_logprobs_val is None:
            return

        # Accumulate input token logprobs (only if list is non-empty)
        if len(batch_out.input_token_logprobs_val) > 0:
            state.input_token_logprobs_val.extend(
                batch_out.input_token_logprobs_val[batch_index]
            )
            state.input_token_logprobs_idx.extend(
                batch_out.input_token_logprobs_idx[batch_index]
            )

        # Always accumulate output token logprobs
        state.output_token_logprobs_val.extend(
            batch_out.output_token_logprobs_val[batch_index]
        )
        state.output_token_logprobs_idx.extend(
            batch_out.output_token_logprobs_idx[batch_index]
        )

        # Handle top logprobs if requested
        if state.obj.top_logprobs_num > 0:
            # Accumulate input top logprobs (only if list is non-empty)
            if len(batch_out.input_top_logprobs_val) > 0:
                state.input_top_logprobs_val.extend(
                    batch_out.input_top_logprobs_val[batch_index]
                )
                state.input_top_logprobs_idx.extend(
                    batch_out.input_top_logprobs_idx[batch_index]
                )

            # Always accumulate output top logprobs
            state.output_top_logprobs_val.extend(
                batch_out.output_top_logprobs_val[batch_index]
            )
            state.output_top_logprobs_idx.extend(
                batch_out.output_top_logprobs_idx[batch_index]
            )

    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
        """Handle batch generation output from scheduler.

        For each known rid in the batch: update metrics, build the response
        dict, attach logprobs, enqueue it, and on completion mark the state
        finished and schedule a delayed cleanup.
        """
        # Process each request in the batch
        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Update metrics
            now = time.time()
            if state.first_token_time == 0.0:
                state.first_token_time = now
            state.last_time = now

            # Extract output for this request
            output_data = {
                "request_id": rid,
                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                "finished": batch_out.finished_reasons[i] is not None,
                "meta_info": {
                    "prompt_tokens": (
                        batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                    ),
                    "completion_tokens": (
                        batch_out.completion_tokens[i]
                        if batch_out.completion_tokens
                        else 0
                    ),
                    "cached_tokens": (
                        batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
                    ),
                    "finish_reason": (
                        str(batch_out.finished_reasons[i])
                        if batch_out.finished_reasons[i]
                        else None
                    ),
                },
            }

            # Accumulate logprobs (following tokenizer_manager pattern)
            if state.obj.return_logprob:
                self._convert_logprob_style(state, batch_out, i)

            # Send input logprobs if available
            if (
                state.obj.return_logprob
                and state.obj.logprob_start_len >= 0
                and state.input_token_logprobs_val
            ):
                if state.obj.stream and not state.input_logprobs_sent:
                    # Streaming: send input logprobs once in first chunk that has them
                    output_data["input_logprobs"] = {
                        "token_logprobs_val": state.input_token_logprobs_val,
                        "token_logprobs_idx": state.input_token_logprobs_idx,
                        "top_logprobs_val": state.input_top_logprobs_val,
                        "top_logprobs_idx": state.input_top_logprobs_idx,
                    }
                    state.input_logprobs_sent = True
                elif not state.obj.stream and output_data["finished"]:
                    # Non-streaming: send input logprobs in final chunk
                    output_data["input_logprobs"] = {
                        "token_logprobs_val": state.input_token_logprobs_val,
                        "token_logprobs_idx": state.input_token_logprobs_idx,
                        "top_logprobs_val": state.input_top_logprobs_val,
                        "top_logprobs_idx": state.input_top_logprobs_idx,
                    }

            # Send output logprobs if available
            if (
                state.obj.return_logprob
                and batch_out.output_token_logprobs_val
                and i < len(batch_out.output_token_logprobs_val)
            ):
                if state.obj.stream:
                    # For streaming: send incremental logprobs (only new tokens in this chunk)
                    # NOTE: this is different than TokenizerManager, which always accumulates
                    def get_part(attr_name):
                        source_list = getattr(batch_out, attr_name, None)
                        return (
                            source_list[i]
                            if source_list and i < len(source_list)
                            else []
                        )

                    output_data["output_logprobs"] = {
                        "token_logprobs_val": batch_out.output_token_logprobs_val[i],
                        "token_logprobs_idx": get_part("output_token_logprobs_idx"),
                        "top_logprobs_val": get_part("output_top_logprobs_val"),
                        "top_logprobs_idx": get_part("output_top_logprobs_idx"),
                    }
                elif output_data["finished"]:
                    # Non-streaming: send cumulative output logprobs in final chunk
                    output_data["output_logprobs"] = {
                        "token_logprobs_val": state.output_token_logprobs_val,
                        "token_logprobs_idx": state.output_token_logprobs_idx,
                        "top_logprobs_val": state.output_top_logprobs_val,
                        "top_logprobs_idx": state.output_top_logprobs_idx,
                    }

            # Update state for accumulation
            if output_data["token_ids"]:
                state.output_ids.extend(output_data["token_ids"])

            await state.out_queue.put(output_data)

            # Handle completion
            if output_data["finished"]:
                state.finished = True
                state.finished_time = now
                state.stream_finished = True
                state.event.set()

                # Remove from tracking after a delay. rid/state are passed
                # explicitly: a closure over the loop variable would see the
                # batch's *last* rid when it finally runs, and could delete a
                # different, still-running request.
                self._spawn_background_task(self._delayed_state_cleanup(rid, state))

    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
        """Handle batch embedding output from scheduler."""
        # BatchEmbeddingOut exposes ``finished_reasons``; fall back to the
        # singular name in case an older/forked struct uses it.
        finish_reasons = getattr(batch_out, "finished_reasons", None)
        if finish_reasons is None:
            finish_reasons = getattr(batch_out, "finish_reason", None)

        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Create result
            result = {
                "request_id": rid,
                "embedding": batch_out.embeddings[i],
                "prompt_tokens": (
                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                ),
                "finish_reason": finish_reasons[i] if finish_reasons else None,
            }

            # Send result
            await state.out_queue.put(result)

            # Mark as finished
            state.finished = True
            state.finished_time = time.time()
            state.event.set()

    async def _handle_health_check_output(self, health_out: HealthCheckOutput):
        """Handle health check output from scheduler."""
        # Tolerate outputs without a rid, the same way the optional fields
        # below are read.
        rid = getattr(health_out, "rid", None)

        if rid not in self.rid_to_state:
            logger.warning(f"Health check output for unknown request: {rid}")
            return

        state = self.rid_to_state[rid]

        # Create health check result
        result = {
            "request_id": rid,
            "healthy": True,  # If we got a response, scheduler is healthy
            "output_text": (
                health_out.output_str if hasattr(health_out, "output_str") else ""
            ),
            "finish_reason": (
                health_out.finish_reason
                if hasattr(health_out, "finish_reason")
                else "stop"
            ),
        }

        # Send result
        await state.out_queue.put(result)

        # Mark as finished
        state.finished = True
        state.finished_time = time.time()
        state.event.set()

    async def _send_to_scheduler(self, obj):
        """Send an object to the scheduler via ZMQ."""
        try:
            self.send_to_scheduler.send_pyobj(obj)
        except Exception as e:
            logger.error(f"Failed to send to scheduler: {e}")
            raise

    def record_request_for_crash_dump(self, obj):
        """Record request for potential crash dump (keeps the most recent ones)."""
        if len(self.crash_dump_request_list) >= self._CRASH_DUMP_MAX_REQUESTS:
            # Drop the oldest entry so a dump reflects the latest traffic
            # rather than the first requests after startup.
            del self.crash_dump_request_list[0]
        self.crash_dump_request_list.append(
            {
                "time": time.time(),
                "request_id": getattr(obj, "rid", "unknown"),
                "type": type(obj).__name__,
            }
        )

    def dump_requests_before_crash(self):
        """Log the recently recorded requests (once) to aid post-mortem debugging.

        Invoked by print_exception_wrapper when a manager coroutine crashes.
        """
        if self.crash_dump_performed:
            return
        self.crash_dump_performed = True
        logger.error(
            f"Dumping {len(self.crash_dump_request_list)} recently recorded "
            f"requests before crash: {self.crash_dump_request_list}"
        )

    async def shutdown(self):
        """Gracefully shutdown the request manager."""
        logger.info("Shutting down GrpcRequestManager")
        self.gracefully_exit = True

        # Cancel all asyncio tasks FIRST - this will interrupt blocked recv() calls
        for task in list(self.asyncio_tasks):
            if not task.done():
                task.cancel()

        # Give tasks a moment to process cancellation
        if self.asyncio_tasks:
            await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

        # Cancel all pending requests
        for rid, state in list(self.rid_to_state.items()):
            if not state.finished:
                await state.out_queue.put(
                    {"error": "Server shutting down", "shutdown": True}
                )
                state.finished = True
                state.event.set()

        # Wait for tasks to complete
        if self.asyncio_tasks:
            await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

        # Shutdown bootstrap server if running
        if self.bootstrap_server:
            logger.info("Shutting down bootstrap server")
            try:
                if hasattr(self.bootstrap_server, "shutdown"):
                    if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
                        await self.bootstrap_server.shutdown()
                    else:
                        self.bootstrap_server.shutdown()
            except Exception as e:
                logger.warning(f"Error shutting down bootstrap server: {e}")

        # Close ZMQ sockets
        self.recv_from_scheduler.close()
        self.send_to_scheduler.close()

        # Terminate the ZMQ context - this is critical for asyncio loop to exit cleanly
        self.context.term()

        logger.info("GrpcRequestManager shutdown complete")

    def get_server_info(self) -> Dict[str, Any]:
        """Get server information for health checks."""
        return {
            "active_requests": len(self.rid_to_state),
            "paused": self.is_pause,
            "last_receive_time": self.last_receive_tstamp,
        }

    def auto_create_handle_loop(self):
        """Automatically create and start the handle_loop task, matching TokenizerManager pattern."""
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        self.event_loop = loop

        # We cannot add signal handler when the grpc manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = GrpcSignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the grpc request manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "grpc request manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        """Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern."""
        while not self.gracefully_exit:
            await asyncio.sleep(1.0)
|
|
|
|
|
|
async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print exception.
    We do another wrapper to handle the exception.

    Awaits ``func()``. If it raises, the traceback is logged. When ``func`` is
    a GrpcRequestManager bound method, recorded requests are dumped if the
    manager supports it. Then the whole process tree is killed.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {traceback}")
        owner = getattr(func, "__self__", None)
        dump = getattr(owner, "dump_requests_before_crash", None)
        if isinstance(owner, GrpcRequestManager) and callable(dump):
            # A failure while dumping must not prevent the teardown below.
            try:
                dump()
            except Exception:
                logger.exception("Failed to dump requests before crash")
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
|