Implement Standalone gRPC Server for SGLang Python Scheduler (#10283)
This commit is contained in:
580
python/sglang/srt/entrypoints/grpc_request_manager.py
Normal file
580
python/sglang/srt/entrypoints/grpc_request_manager.py
Normal file
@@ -0,0 +1,580 @@
|
||||
"""
|
||||
gRPC Request Manager - Orchestrates request lifecycle without tokenization.
|
||||
Mimics TokenizerManager's state management and ZMQ communication patterns.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import grpc
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
HealthCheckOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GrpcSignalHandler:
    """Signal callbacks used by the gRPC server process.

    The handlers only raise a shutdown flag or tear the process tree down;
    as the log messages state, crash dumps are left to the scheduler.
    """

    def __init__(self, grpc_manager):
        # Manager whose ``gracefully_exit`` flag is raised on SIGTERM.
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Log the SIGTERM and ask the manager to exit gracefully."""
        notice = f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        logger.warning(notice)
        self.grpc_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Log the scheduler failure, then kill this process tree."""
        logger.error(
            "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
        )
        logger.info(
            "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
        )
        # Nothing is dumped here; the whole tree (parent included) is killed.
        own_pid = os.getpid()
        kill_process_tree(own_pid, include_parent=True)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class GrpcReqState:
    """Bookkeeping record the request manager keeps for one gRPC request."""

    # --- Identification ---
    request_id: str
    grpc_context: Optional[grpc.aio.ServicerContext]

    # --- Communication ---
    # Queue the manager fills with results for the caller of the request.
    out_queue: asyncio.Queue
    finished: bool
    # Set when the request is finished, aborted, or the manager shuts down.
    event: asyncio.Event
    # The tokenized request object that was sent to the scheduler.
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # --- Timing metrics (0.0 means "not recorded yet") ---
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # --- Streaming progress ---
    # Length of the most recent text payload seen for this request.
    last_output_offset: int = 0
    stream_finished: bool = False

    # --- Accumulated output ---
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)

    # --- Session tracking ---
    session_id: Optional[str] = None
    is_session_request: bool = False
|
||||
|
||||
|
||||
class GrpcRequestManager:
|
||||
"""
|
||||
Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
|
||||
behaviors without tokenization.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    server_args: ServerArgs,
    port_args: PortArgs,
):
    """Create the ZMQ sockets and the in-memory request bookkeeping.

    Args:
        server_args: Server configuration; stored on the instance.
        port_args: Provides the IPC names the two ZMQ sockets are bound to.
    """
    self.server_args = server_args
    self.port_args = port_args

    # One asyncio ZMQ context serves both sockets; both are bound here.
    zmq_ctx = zmq.asyncio.Context(2)

    # PULL side: outputs coming back from the scheduler.
    self.recv_from_scheduler = get_zmq_socket(
        zmq_ctx, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
    )
    # PUSH side: requests going out to the scheduler.
    self.send_to_scheduler = get_zmq_socket(
        zmq_ctx, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
    )

    # Request tracking and lifecycle flags.
    self.rid_to_state: Dict[str, GrpcReqState] = {}
    self.asyncio_tasks: set = set()
    self.gracefully_exit = False
    self.no_create_loop = False
    self.event_loop = None

    # Pause/resume gate consulted by handle_loop.
    self.is_pause = False
    self.is_pause_cond = asyncio.Condition()

    # Counter used to mint request ids, plus the last-receive timestamp.
    self.request_counter = 0
    self.request_counter_lock = asyncio.Lock()
    self.last_receive_tstamp = time.time()

    # Lightweight request log kept for crash diagnostics.
    self.crash_dump_request_list = []
    self.crash_dump_performed = False

    logger.info(
        f"GrpcRequestManager initialized with ZMQ IPC: "
        f"recv={port_args.detokenizer_ipc_name}, "
        f"send={port_args.scheduler_input_ipc_name}"
    )
|
||||
|
||||
async def generate_request(
    self,
    obj: TokenizedGenerateReqInput,
    request_id: Optional[str] = None,
    grpc_context: Optional[grpc.aio.ServicerContext] = None,
) -> asyncio.Queue:
    """Register a generation request and forward it to the scheduler.

    Args:
        obj: Tokenized request; its ``rid`` is overwritten with the id used here.
        request_id: Id to track the request under. When ``None`` an id of the
            form ``grpc-<counter>`` is minted.
        grpc_context: Optional RPC context stored on the request state.

    Returns:
        The queue on which outputs for this request will be delivered.

    Raises:
        RuntimeError: If the request could not be handed to the scheduler.
            The original exception is attached as ``__cause__``.
    """
    if request_id is None:
        async with self.request_counter_lock:
            request_id = f"grpc-{self.request_counter}"
            self.request_counter += 1

    obj.rid = request_id

    # TODO: support log_request

    state = GrpcReqState(
        request_id=request_id,
        grpc_context=grpc_context,
        out_queue=asyncio.Queue(),
        finished=False,
        event=asyncio.Event(),
        obj=obj,
        created_time=time.time(),
    )

    # Remember the session when the request carries session parameters.
    session_params = getattr(obj, "session_params", None)
    if session_params:
        state.session_id = session_params.session_id
        state.is_session_request = True

    self.rid_to_state[request_id] = state
    self.record_request_for_crash_dump(obj)

    try:
        await self._send_to_scheduler(obj)
    except Exception as e:
        # pop() instead of del: the entry may already be gone, and a KeyError
        # here would hide the real send failure.
        self.rid_to_state.pop(request_id, None)
        raise RuntimeError(f"Failed to send request to scheduler: {e}") from e

    return state.out_queue
|
||||
|
||||
async def embedding_request(
    self,
    obj: TokenizedEmbeddingReqInput,
    request_id: Optional[str] = None,
) -> asyncio.Future:
    """Register an embedding request and forward it to the scheduler.

    Args:
        obj: Tokenized request; its ``rid`` is overwritten with the id used here.
        request_id: Id to track the request under. When ``None`` an id of the
            form ``grpc-embed-<counter>`` is minted.

    Returns:
        A future resolved with the first item put on the request's output
        queue once the request's event is set. If sending fails, the returned
        future already carries that exception.
    """
    if request_id is None:
        async with self.request_counter_lock:
            request_id = f"grpc-embed-{self.request_counter}"
            self.request_counter += 1

    obj.rid = request_id

    state = GrpcReqState(
        request_id=request_id,
        grpc_context=None,
        out_queue=asyncio.Queue(),
        finished=False,
        event=asyncio.Event(),
        obj=obj,
        created_time=time.time(),
    )
    self.rid_to_state[request_id] = state

    future = asyncio.get_running_loop().create_future()

    try:
        await self._send_to_scheduler(obj)
    except Exception as e:
        self.rid_to_state.pop(request_id, None)
        future.set_exception(e)
        return future

    async def wait_for_result():
        """Move the queued result into the future, then drop the state."""
        try:
            await state.event.wait()
            result = await state.out_queue.get()
            # The caller may have cancelled the future in the meantime.
            if not future.done():
                future.set_result(result)
        except asyncio.CancelledError:
            # Do not leave the caller waiting on a future nobody will resolve.
            if not future.done():
                future.cancel()
            raise
        except Exception as e:
            if not future.done():
                future.set_exception(e)
        finally:
            self.rid_to_state.pop(request_id, None)

    # Keep a strong reference so the waiter cannot be garbage-collected
    # mid-flight; the reference is dropped again when the task completes.
    waiter = asyncio.create_task(wait_for_result())
    self.asyncio_tasks.add(waiter)
    waiter.add_done_callback(self.asyncio_tasks.discard)
    return future
|
||||
|
||||
async def abort_request(self, request_id: str) -> bool:
    """Ask the scheduler to abort a tracked request and notify its consumer.

    Returns:
        ``False`` when the id is not tracked or the abort could not be sent,
        ``True`` otherwise.
    """
    if request_id not in self.rid_to_state:
        return False

    abort_req = AbortReq(rid=request_id)
    try:
        await self._send_to_scheduler(abort_req)
    except Exception as e:
        logger.error(f"Failed to send abort request: {e}")
        return False

    state = self.rid_to_state.get(request_id)
    if not state:
        # The entry disappeared while the abort was being sent.
        return True

    state.finished = True
    state.stream_finished = True
    state.event.set()
    # Tell whoever is reading the queue that the request was aborted.
    await state.out_queue.put({"error": "Request aborted", "abort": True})
    return True
|
||||
|
||||
async def pause_generation(self):
    """Raise the pause flag that handle_loop checks before dispatching."""
    gate = self.is_pause_cond
    await gate.acquire()
    try:
        self.is_pause = True
        logger.info("Generation paused")
    finally:
        gate.release()
|
||||
|
||||
async def resume_generation(self):
    """Clear the pause flag and wake every coroutine waiting on it."""
    gate = self.is_pause_cond
    await gate.acquire()
    try:
        self.is_pause = False
        gate.notify_all()
        logger.info("Generation resumed")
    finally:
        gate.release()
|
||||
|
||||
async def handle_loop(self):
    """Receive scheduler outputs and dispatch them until shutdown is requested.

    Each received object is routed by type to the matching ``_handle_*``
    coroutine. Errors are logged and the loop keeps going; it ends once
    ``gracefully_exit`` is set.
    """
    while not self.gracefully_exit:
        try:
            recv_obj = await self.recv_from_scheduler.recv_pyobj()
            self.last_receive_tstamp = time.time()

            # Hold the received object back while generation is paused.
            async with self.is_pause_cond:
                await self.is_pause_cond.wait_for(lambda: not self.is_pause)

            if isinstance(recv_obj, BatchTokenIDOut):
                handler = self._handle_batch_output
            elif isinstance(recv_obj, BatchEmbeddingOut):
                handler = self._handle_embedding_output
            elif isinstance(recv_obj, HealthCheckOutput):
                handler = self._handle_health_check_output
            else:
                handler = None

            if handler is None:
                logger.warning(f"Unknown output type: {type(recv_obj)}")
            else:
                await handler(recv_obj)

        except zmq.error.Again:
            # Receive timed out; the loop condition re-checks the exit flag.
            continue
        except Exception as e:
            logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
            # Fall through: the loop condition decides whether to stop.
|
||||
|
||||
async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
    """Fan one batch of generation outputs out to the per-request queues.

    For every rid in the batch that is still tracked, this updates the timing
    fields on its state, builds an output dict, appends text/token ids to the
    state, and puts the dict on the request's queue. A finished request is
    marked as such and its state is dropped from ``rid_to_state`` 5 seconds
    later.

    Args:
        batch_out: Batched output received from the scheduler.
    """
    loop = asyncio.get_running_loop()

    def _forget(finished_rid, finished_state):
        # Remove the entry only if it still belongs to the request that
        # finished; the id may have been dropped or re-registered meanwhile.
        if self.rid_to_state.get(finished_rid) is finished_state:
            del self.rid_to_state[finished_rid]

    for i, rid in enumerate(batch_out.rids):
        state = self.rid_to_state.get(rid)
        if state is None:
            continue

        # Timing metrics.
        now = time.time()
        if state.first_token_time == 0.0:
            state.first_token_time = now
        state.last_time = now

        finish_reason = batch_out.finished_reasons[i]
        output_data = {
            "request_id": rid,
            "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
            "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
            "finished": finish_reason is not None,
            "meta_info": {
                "prompt_tokens": (
                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                ),
                "completion_tokens": (
                    batch_out.completion_tokens[i]
                    if batch_out.completion_tokens
                    else 0
                ),
                "finish_reason": str(finish_reason) if finish_reason else None,
            },
        }

        # Attach logprobs when the batch carries them for this index.
        token_logprobs = batch_out.output_token_logprobs_val
        if token_logprobs and i < len(token_logprobs):
            top_logprobs = batch_out.output_top_logprobs_val
            output_data["logprobs"] = {
                "tokens": token_logprobs[i],
                "top_logprobs": (
                    top_logprobs[i]
                    if top_logprobs and i < len(top_logprobs)
                    else None
                ),
            }

        # Accumulate on the state.
        if output_data["text"]:
            state.text += output_data["text"][state.last_output_offset :]
            state.last_output_offset = len(output_data["text"])

        if output_data["token_ids"]:
            state.output_ids.extend(output_data["token_ids"])

        await state.out_queue.put(output_data)

        if output_data["finished"]:
            state.finished = True
            state.finished_time = now
            state.stream_finished = True
            state.event.set()

            # Delayed removal. rid/state are passed as arguments so each
            # callback acts on its own request; a closure over the loop
            # variable would only ever see the batch's last rid.
            loop.call_later(5.0, _forget, rid, state)
|
||||
|
||||
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
    """Deliver each embedding in the batch to its request and mark it finished.

    Args:
        batch_out: Batched embedding output received from the scheduler.
    """
    # NOTE(review): the generation batch handled in this file exposes
    # `finished_reasons`; presumably the embedding batch does too. Confirm
    # against io_struct. Both spellings are accepted so that a missing
    # `finish_reason` attribute no longer aborts the whole batch.
    reasons = getattr(batch_out, "finish_reason", None)
    if reasons is None:
        reasons = getattr(batch_out, "finished_reasons", None)

    for i, rid in enumerate(batch_out.rids):
        state = self.rid_to_state.get(rid)
        if state is None:
            continue

        result = {
            "request_id": rid,
            "embedding": batch_out.embeddings[i],
            "prompt_tokens": (
                batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
            ),
            "finish_reason": reasons[i] if reasons else None,
        }

        await state.out_queue.put(result)

        # Setting the event releases whoever waits for this request.
        state.finished = True
        state.finished_time = time.time()
        state.event.set()
|
||||
|
||||
async def _handle_health_check_output(self, health_out: HealthCheckOutput):
    """Deliver a health-check reply to the request that issued it.

    Replies for ids that are not tracked are logged and ignored.
    """
    rid = health_out.rid

    try:
        state = self.rid_to_state[rid]
    except KeyError:
        logger.warning(f"Health check output for unknown request: {rid}")
        return

    # Receiving any reply is treated as proof that the scheduler is alive.
    result = {
        "request_id": rid,
        "healthy": True,
        "output_text": getattr(health_out, "output_str", ""),
        "finish_reason": getattr(health_out, "finish_reason", "stop"),
    }

    await state.out_queue.put(result)

    state.finished = True
    state.finished_time = time.time()
    state.event.set()
|
||||
|
||||
async def _send_to_scheduler(self, obj):
    """Send one pickled object to the scheduler socket.

    If ``send_pyobj`` returns an awaitable (as an asyncio ZMQ socket does),
    it is awaited so that failures surface here instead of being lost on an
    unobserved future. A plain return value is accepted as-is.

    Raises:
        Exception: Whatever the socket raised; it is logged and re-raised.
    """
    import inspect  # local import: only this method needs it

    try:
        pending = self.send_to_scheduler.send_pyobj(obj)
        if inspect.isawaitable(pending):
            await pending
    except Exception as e:
        logger.error(f"Failed to send to scheduler: {e}")
        raise
|
||||
|
||||
def record_request_for_crash_dump(self, obj):
    """Append a small summary of ``obj`` to the crash-dump list.

    Recording stops once the list holds 100 entries.
    """
    records = self.crash_dump_request_list
    if len(records) >= 100:
        return

    entry = {
        "time": time.time(),
        "request_id": getattr(obj, "rid", "unknown"),
        "type": type(obj).__name__,
    }
    records.append(entry)
|
||||
|
||||
async def shutdown(self):
    """Stop the manager: notify pending requests, stop tasks, close sockets.

    Every unfinished request receives a shutdown error on its queue and has
    its event set. Background tasks are then given one loop iteration to
    finish on their own; any that are still running are cancelled before
    being awaited, so shutdown cannot block on a task parked in a receive.
    """
    logger.info("Shutting down GrpcRequestManager")
    self.gracefully_exit = True

    # Iterate over a snapshot: woken waiters may remove entries.
    for state in list(self.rid_to_state.values()):
        if not state.finished:
            await state.out_queue.put(
                {"error": "Server shutting down", "shutdown": True}
            )
            state.finished = True
            state.event.set()

    # Never await (or cancel) the task that is running shutdown itself.
    current = asyncio.current_task()
    tasks = [task for task in self.asyncio_tasks if task is not current]
    if tasks:
        # Let tasks that were just woken run to completion first.
        await asyncio.sleep(0)
        for task in tasks:
            if not task.done():
                # NOTE(review): no receive timeout is visible in this file,
                # so handle_loop may wait in recv_pyobj() indefinitely.
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

    self.recv_from_scheduler.close()
    self.send_to_scheduler.close()

    logger.info("GrpcRequestManager shutdown complete")
|
||||
|
||||
def get_server_info(self) -> Dict[str, Any]:
    """Return a snapshot of the tracked-request count, pause flag and last receive time."""
    info: Dict[str, Any] = {}
    info["active_requests"] = len(self.rid_to_state)
    info["paused"] = self.is_pause
    info["last_receive_time"] = self.last_receive_tstamp
    return info
|
||||
|
||||
def auto_create_handle_loop(self):
    """Start the background tasks once; later calls are no-ops.

    Schedules ``handle_loop`` and ``sigterm_watchdog`` on the current event
    loop and, when running in the main thread, installs the SIGTERM and
    SIGQUIT handlers.
    """
    if self.no_create_loop:
        return
    self.no_create_loop = True

    loop = asyncio.get_event_loop()
    handle_task = loop.create_task(print_exception_wrapper(self.handle_loop))
    self.asyncio_tasks.add(handle_task)

    self.event_loop = loop

    # Signal handlers can only be installed from the main thread (CPython
    # limitation), so other threads just get a warning.
    in_main_thread = threading.current_thread() is threading.main_thread()
    if not in_main_thread:
        logger.warning(
            "Signal handler is not added because the grpc request manager is "
            "not in the main thread. This disables graceful shutdown of the "
            "grpc request manager when SIGTERM is received."
        )
    else:
        signal_handler = GrpcSignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
        # Replaces the SIGQUIT handler that was active during the launch phase.
        loop.add_signal_handler(
            signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
        )

    watchdog_task = loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
    self.asyncio_tasks.add(watchdog_task)
|
||||
|
||||
async def sigterm_watchdog(self):
    """Poll ``gracefully_exit`` once per second and return when it is set."""
    while True:
        if self.gracefully_exit:
            return
        await asyncio.sleep(1.0)
|
||||
|
||||
|
||||
async def print_exception_wrapper(func):
    """Await ``func()`` and bring the process down if it raises.

    Exceptions raised inside asyncio tasks are easy to miss, so any
    ``Exception`` is logged with its traceback, the owning manager is given a
    chance to dump its requests, and the process tree is killed.

    Args:
        func: Zero-argument coroutine function, typically a bound method of
            ``GrpcRequestManager``.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {traceback}")

        owner = getattr(func, "__self__", None)
        if isinstance(owner, GrpcRequestManager):
            # The dump hook is optional. A missing or failing hook must not
            # prevent the process from being terminated below.
            dump = getattr(owner, "dump_requests_before_crash", None)
            if callable(dump):
                try:
                    dump()
                except Exception:
                    logger.error(
                        f"Failed to dump requests before crash: {get_exception_traceback()}"
                    )

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
|
||||
680
python/sglang/srt/entrypoints/grpc_server.py
Normal file
680
python/sglang/srt/entrypoints/grpc_server.py
Normal file
@@ -0,0 +1,680 @@
|
||||
"""
|
||||
Standalone gRPC Server for SGLang - Fully separated from HTTP server.
|
||||
Uses GrpcRequestManager for orchestration without tokenization.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from concurrent import futures
|
||||
from typing import AsyncIterator, Dict, Optional, Tuple
|
||||
|
||||
import grpc
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
|
||||
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
|
||||
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
|
||||
from sglang.srt.managers.data_parallel_controller import (
|
||||
run_data_parallel_controller_process,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Seconds HealthCheck waits for a reply; override with SGLANG_HEALTH_CHECK_TIMEOUT.
HEALTH_CHECK_TIMEOUT = int(os.environ.get("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||
|
||||
|
||||
def _launch_scheduler_process_only(
    server_args: ServerArgs,
    port_args: Optional[PortArgs] = None,
) -> Tuple[Dict, PortArgs, list]:
    """Start the scheduler process(es) and wait until each reports ready.

    With ``dp_size == 1`` one scheduler process is started per (pp_rank,
    tp_rank) pair assigned to this node; otherwise a single data-parallel
    controller process is started.

    Args:
        server_args: Server configuration. ``model_path`` and
            ``tokenizer_path`` are replaced with the prepared paths.
        port_args: IPC port configuration; created from ``server_args`` when
            ``None``.

    Returns:
        The first ready message received, the port args in use, and the list
        of started processes.

    Raises:
        RuntimeError: If a process closes its pipe before reporting, or
            reports a status other than ``"ready"``.
    """
    # Global environment setup.
    configure_logger(server_args)
    server_args.check_server_args()

    # Ports for inter-process communication.
    if port_args is None:
        port_args = PortArgs.init_new(server_args)
        logger.info(f"{server_args=}")

    # Resolve model and tokenizer paths.
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    scheduler_procs = []
    if server_args.dp_size != 1:
        # Data-parallel mode: one controller process.
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()
        scheduler_procs.append(proc)
    else:
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        scheduler_pipe_readers = []

        # TP ranks handled by this node.
        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
        tp_node_index = server_args.node_rank % nnodes_per_tp_group
        tp_rank_range = range(
            tp_size_per_node * tp_node_index,
            tp_size_per_node * (tp_node_index + 1),
        )

        # PP ranks handled by this node.
        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
        pp_node_index = server_args.node_rank // nnodes_per_tp_group
        pp_rank_range = range(
            pp_size_per_node * pp_node_index,
            pp_size_per_node * (pp_node_index + 1),
        )

        for pp_rank in pp_rank_range:
            for tp_rank in tp_rank_range:
                reader, writer = mp.Pipe(duplex=False)
                pp_offset = (pp_rank % pp_size_per_node) * tp_size_per_node
                tp_offset = (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                gpu_id = server_args.base_gpu_id + pp_offset + tp_offset
                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
                proc = mp.Process(
                    target=run_scheduler_process,
                    args=(
                        server_args,
                        port_args,
                        gpu_id,
                        tp_rank,
                        moe_ep_rank,
                        pp_rank,
                        None,
                        writer,
                        None,
                    ),
                )

                with memory_saver_adapter.configure_subprocess():
                    proc.start()
                scheduler_procs.append(proc)
                scheduler_pipe_readers.append(reader)

    # TODO(CatherineSue): handle cases for multi-node

    # Block until every started process has reported in.
    scheduler_infos = []
    for i, reader in enumerate(scheduler_pipe_readers):
        try:
            data = reader.recv()
        except EOFError:
            # The writer end closed without sending anything.
            logger.error(
                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
            )
            scheduler_procs[i].join()
            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
            raise RuntimeError(f"Failed to initialize scheduler rank {i}")

        if data.get("status") != "ready":
            raise RuntimeError(
                f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
            )
        scheduler_infos.append(data)

    logger.info(
        f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
    )

    # Only the first ready message is returned to the caller.
    return scheduler_infos[0], port_args, scheduler_procs
|
||||
|
||||
|
||||
class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
|
||||
"""
|
||||
Standalone gRPC service implementation using GrpcRequestManager.
|
||||
Fully separated from HTTP server with its own process and no shared globals.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    request_manager: GrpcRequestManager,
    server_args: ServerArgs,
    model_info: Dict,
):
    """Store the collaborators and start the request manager's background loop.

    Args:
        request_manager: Manager that submits requests and delivers outputs.
        server_args: Server configuration; stored on the instance.
        model_info: Model metadata; stored on the instance.
    """
    self.request_manager = request_manager
    self.server_args = server_args
    self.model_info = model_info
    # Wall-clock time at which this servicer was created.
    self.start_time = time.time()

    # Kick off the manager's receive loop (a no-op if already started).
    request_manager.auto_create_handle_loop()

    logger.info("Standalone gRPC scheduler service initialized")
|
||||
|
||||
async def Generate(
    self,
    request: sglang_scheduler_pb2.GenerateRequest,
    context: grpc.aio.ServicerContext,
) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
    """Stream generation responses for one request.

    The request is converted and submitted to the request manager, then
    outputs are read from its queue and yielded as chunk, completion or
    error responses. The stream ends on the first error or finished output.
    If the stream ends any other way after submission (handler cancelled or
    closed, or an unexpected exception), the request is aborted in the
    manager so the scheduler does not keep working on it.

    Args:
        request: The gRPC generate request.
        context: RPC context, polled for client cancellation.

    Yields:
        ``GenerateResponse`` messages for this request.
    """
    logger.info(f"Generation request: {request.request_id}")

    # True while the manager still tracks a live request that this handler
    # has not seen a terminal output for.
    needs_abort = False

    try:
        tokenized_req = self._convert_generate_request(request)

        output_queue = await self.request_manager.generate_request(
            obj=tokenized_req,
            request_id=request.request_id,
            grpc_context=context,
        )
        needs_abort = True

        while True:
            try:
                # Wake up periodically so client cancellation is noticed.
                output = await asyncio.wait_for(output_queue.get(), timeout=4)
            except asyncio.TimeoutError:
                if context.cancelled():
                    needs_abort = False
                    await self.request_manager.abort_request(request.request_id)
                    break
                continue

            if "error" in output:
                needs_abort = False
                yield sglang_scheduler_pb2.GenerateResponse(
                    request_id=request.request_id,
                    error=sglang_scheduler_pb2.GenerateError(
                        message=output["error"],
                        http_status_code=(
                            "500" if "abort" not in output else "499"
                        ),
                    ),
                )
                break

            if output.get("finished", False):
                needs_abort = False
                yield self._create_completion_response(request.request_id, output)
                break

            yield self._create_chunk_response(request.request_id, output)

    except Exception as e:
        details = get_exception_traceback()
        logger.error(f"Generate failed: {e}\n{details}")
        yield sglang_scheduler_pb2.GenerateResponse(
            request_id=request.request_id,
            error=sglang_scheduler_pb2.GenerateError(
                message=str(e),
                http_status_code="500",
                details=details,
            ),
        )
    finally:
        # Reached with needs_abort still set when the handler was cancelled
        # or closed, or failed mid-stream.
        if needs_abort:
            await self.request_manager.abort_request(request.request_id)
|
||||
|
||||
async def Embed(
    self,
    request: sglang_scheduler_pb2.EmbedRequest,
    context: grpc.aio.ServicerContext,
) -> sglang_scheduler_pb2.EmbedResponse:
    """Compute an embedding for one request.

    Args:
        request: The gRPC embed request.
        context: RPC context (unused here).

    Returns:
        An ``EmbedResponse`` carrying either the embedding or an error.
        ``generation_time`` is the wall-clock time spent handling this call.
    """
    logger.info(f"Embedding request: {request.request_id}")
    # Measure this request, not the time since the servicer was created.
    started_at = time.time()

    try:
        tokenized_req = self._convert_embed_request(request)

        future = await self.request_manager.embedding_request(
            obj=tokenized_req,
            request_id=request.request_id,
        )
        result = await future

        # The manager delivers abort/shutdown notices as {"error": ...} dicts.
        if "error" in result:
            return sglang_scheduler_pb2.EmbedResponse(
                request_id=request.request_id,
                error=sglang_scheduler_pb2.EmbedError(
                    message=str(result["error"]),
                    code="INTERNAL_ERROR",
                ),
            )

        embedding = result["embedding"]
        return sglang_scheduler_pb2.EmbedResponse(
            request_id=request.request_id,
            complete=sglang_scheduler_pb2.EmbedComplete(
                embedding=embedding,
                prompt_tokens=result.get("prompt_tokens", 0),
                cached_tokens=0,
                embedding_dim=len(embedding),
                generation_time=time.time() - started_at,
            ),
        )

    except Exception as e:
        details = get_exception_traceback()
        logger.error(f"Embed failed: {e}\n{details}")
        return sglang_scheduler_pb2.EmbedResponse(
            request_id=request.request_id,
            error=sglang_scheduler_pb2.EmbedError(
                message=str(e),
                code="INTERNAL_ERROR",
                details=details,
            ),
        )
|
||||
|
||||
async def HealthCheck(
    self,
    request: sglang_scheduler_pb2.HealthCheckRequest,
    context: grpc.aio.ServicerContext,
) -> sglang_scheduler_pb2.HealthCheckResponse:
    """Probe the scheduler by running a one-token generation.

    Args:
        request: Health-check request; must carry a ``tokenized`` input.
        context: RPC context (unused here).

    Returns:
        ``healthy=True`` when a non-error reply arrives within
        ``HEALTH_CHECK_TIMEOUT`` seconds. ``healthy=False`` when the server
        is shutting down, the input is missing, the probe times out, the
        reply is an error notice, or an exception occurs.
    """
    try:
        if self.request_manager.gracefully_exit:
            return sglang_scheduler_pb2.HealthCheckResponse(
                healthy=False, message="Server shutting down"
            )

        if not request.HasField("tokenized"):
            return sglang_scheduler_pb2.HealthCheckResponse(
                healthy=False, message="Tokenized input required for health check"
            )

        input_text = request.tokenized.original_text
        input_ids = list(request.tokenized.input_ids)

        rid = f"HEALTH_CHECK_GRPC_{time.time()}"

        # Cheapest possible probe: a single greedy token.
        health_request = TokenizedGenerateReqInput(
            rid=rid,
            input_text=input_text,
            input_ids=input_ids,
            sampling_params=SGLSamplingParams(max_new_tokens=1, temperature=0.0),
            stream=False,
            mm_inputs=None,
            return_logprob=False,
            logprob_start_len=-1,
            top_logprobs_num=0,
            token_ids_logprob=None,
        )

        logger.info("Sending health check request to request manager...")

        output_queue = await self.request_manager.generate_request(
            health_request, request_id=rid
        )

        try:
            response = await asyncio.wait_for(
                output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
            )
        except asyncio.TimeoutError:
            return sglang_scheduler_pb2.HealthCheckResponse(
                healthy=False, message="Health check timeout"
            )
        finally:
            # Drop the probe's state on every exit path.
            self.request_manager.rid_to_state.pop(rid, None)

        # Abort/shutdown notices arrive as {"error": ...}; they are not a
        # sign of a healthy scheduler.
        if isinstance(response, dict) and "error" in response:
            return sglang_scheduler_pb2.HealthCheckResponse(
                healthy=False,
                message=f"Health check error: {response['error']}",
            )

        return sglang_scheduler_pb2.HealthCheckResponse(
            healthy=True, message="Health check passed"
        )

    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return sglang_scheduler_pb2.HealthCheckResponse(
            healthy=False, message=f"Health check error: {str(e)}"
        )
|
||||
|
||||
async def Abort(
    self,
    request: sglang_scheduler_pb2.AbortRequest,
    context: grpc.aio.ServicerContext,
) -> sglang_scheduler_pb2.AbortResponse:
    """Ask the request manager to abort the request named in ``request``.

    Args:
        request: Abort request; only ``request_id`` is read.
        context: gRPC servicer context (unused here).

    Returns:
        An ``AbortResponse`` whose ``success`` flag is whatever the request
        manager's ``abort_request`` returned. Any exception raised while
        aborting or while building the reply is logged and reported as
        ``success=False`` with the exception text as the message.
    """
    rid = request.request_id
    logger.info(f"Aborting request: {rid}")

    try:
        success = await self.request_manager.abort_request(rid)
        outcome = "aborted" if success else "not found"
        reply = sglang_scheduler_pb2.AbortResponse(
            success=success,
            message=f"Request {rid} {outcome}",
        )
        return reply
    except Exception as e:
        # Failures are reported in the response payload rather than
        # propagated to the gRPC layer.
        logger.error(f"Abort failed: {e}")
        return sglang_scheduler_pb2.AbortResponse(success=False, message=str(e))
|
||||
|
||||
# Helper methods for request/response conversion
|
||||
|
||||
def _convert_generate_request(
    self, grpc_req: sglang_scheduler_pb2.GenerateRequest
) -> TokenizedGenerateReqInput:
    """Build a ``TokenizedGenerateReqInput`` from a gRPC ``GenerateRequest``.

    Args:
        grpc_req: Incoming request; its ``tokenized`` field must be set.

    Returns:
        The internal request object. ``stream`` is always ``True`` and
        ``mm_inputs`` is always ``None``.

    Raises:
        ValueError: If ``grpc_req`` carries no ``tokenized`` field.
    """
    if not grpc_req.HasField("tokenized"):
        raise ValueError("Tokenized input must be provided")

    tokenized = grpc_req.tokenized

    # An empty repeated field is forwarded as None rather than [].
    requested_token_ids = grpc_req.token_ids_logprob
    token_ids_logprob = list(requested_token_ids) if requested_token_ids else None

    # An empty LoRA id means "no adapter".
    lora_path = grpc_req.lora_id if grpc_req.lora_id else None

    return TokenizedGenerateReqInput(
        rid=grpc_req.request_id,
        input_text=tokenized.original_text,
        input_ids=list(tokenized.input_ids),
        mm_inputs=None,  # TODO: implement mm support
        sampling_params=self._convert_sampling_params(grpc_req.sampling_params),
        return_logprob=grpc_req.return_logprob,
        # NOTE(review): `or` maps an explicit logprob_start_len of 0 to -1,
        # because 0 is falsy — confirm that callers never need 0 here.
        logprob_start_len=grpc_req.logprob_start_len or -1,
        top_logprobs_num=grpc_req.top_logprobs_num or 0,
        stream=True,  # Always stream for gRPC
        lora_path=lora_path,
        token_ids_logprob=token_ids_logprob,
    )
|
||||
|
||||
def _convert_embed_request(
    self, grpc_req: sglang_scheduler_pb2.EmbedRequest
) -> TokenizedEmbeddingReqInput:
    """Build a ``TokenizedEmbeddingReqInput`` from a gRPC ``EmbedRequest``.

    Args:
        grpc_req: Incoming request; its ``tokenized`` field must be set.

    Returns:
        The internal embedding request carrying the request id, the original
        text and a list copy of the input token ids.

    Raises:
        ValueError: If ``grpc_req`` carries no ``tokenized`` field.
    """
    if not grpc_req.HasField("tokenized"):
        raise ValueError("Tokenized input must be provided")

    tokenized = grpc_req.tokenized
    return TokenizedEmbeddingReqInput(
        rid=grpc_req.request_id,
        input_text=tokenized.original_text,
        input_ids=list(tokenized.input_ids),
    )
|
||||
|
||||
def _convert_sampling_params(
    self, grpc_params: sglang_scheduler_pb2.SamplingParams
) -> SGLSamplingParams:
    """Translate gRPC ``SamplingParams`` into ``SGLSamplingParams``.

    At most one structured-output constraint is forwarded: the first of
    ``regex``, ``json_schema`` and ``ebnf_grammar`` (checked in that order)
    for which ``HasField`` is true. Numeric fields that are falsy on the
    wire are replaced by the fallback values spelled out below.

    Args:
        grpc_params: Sampling parameters received over gRPC.

    Returns:
        The equivalent internal sampling-parameter object.
    """
    # Pick the first constraint that is present; the others stay None.
    constraints = {"regex": None, "json_schema": None, "ebnf_grammar": None}
    for field_name in constraints:
        if grpc_params.HasField(field_name):
            constraints[field_name] = getattr(grpc_params, field_name)
            break

    # Empty repeated fields are forwarded as None rather than [].
    stop_strings = list(grpc_params.stop) if grpc_params.stop else None
    stop_ids = (
        list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
    )

    # NOTE(review): `value or default` cannot tell an explicit zero from an
    # unset field, so e.g. temperature=0.0 is replaced by 1.0 and
    # max_new_tokens=0 by 128 — confirm this is the intended contract.
    return SGLSamplingParams(
        temperature=grpc_params.temperature or 1.0,
        top_p=grpc_params.top_p or 1.0,
        top_k=grpc_params.top_k or -1,
        min_p=grpc_params.min_p or 0.0,
        frequency_penalty=grpc_params.frequency_penalty or 0.0,
        presence_penalty=grpc_params.presence_penalty or 0.0,
        repetition_penalty=grpc_params.repetition_penalty or 1.0,
        max_new_tokens=grpc_params.max_new_tokens or 128,
        min_new_tokens=grpc_params.min_new_tokens or 0,
        stop=stop_strings,
        stop_token_ids=stop_ids,
        skip_special_tokens=grpc_params.skip_special_tokens,
        spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
        regex=constraints["regex"],
        json_schema=constraints["json_schema"],
        ebnf=constraints["ebnf_grammar"],
        n=grpc_params.n or 1,
        ignore_eos=grpc_params.ignore_eos,
    )
|
||||
|
||||
def _create_chunk_response(
    self, request_id: str, output: Dict
) -> sglang_scheduler_pb2.GenerateResponse:
    """Wrap one streamed ``output`` dict in a ``GenerateResponse`` chunk.

    Args:
        request_id: Id echoed back in the response.
        output: Dict read for the optional keys ``"token_ids"`` and
            ``"text"``.

    Returns:
        A response whose ``chunk`` carries the last token id (0 when there
        are no token ids), the text (``""`` when absent) and the number of
        token ids. ``prompt_tokens``, ``cached_tokens`` and ``queue_time``
        are always reported as zero.
    """
    token_ids = output.get("token_ids")
    last_token_id = token_ids[-1] if token_ids else 0

    # NOTE(review): elapsed time is measured from self.start_time, which is
    # set outside this method — confirm whether that is per-request.
    chunk = sglang_scheduler_pb2.GenerateStreamChunk(
        token_id=last_token_id,
        text=output.get("text", ""),
        prompt_tokens=0,
        completion_tokens=len(output.get("token_ids", [])),
        cached_tokens=0,
        generation_time=time.time() - self.start_time,
        queue_time=0.0,
    )
    return sglang_scheduler_pb2.GenerateResponse(request_id=request_id, chunk=chunk)
|
||||
|
||||
def _create_completion_response(
    self, request_id: str, output: Dict
) -> sglang_scheduler_pb2.GenerateResponse:
    """Wrap the final ``output`` dict in a ``GenerateResponse`` completion.

    The finish reason is taken from ``output["meta_info"]["finish_reason"]``:
    ``"length"`` maps to ``LENGTH``, ``"eos_token"`` to ``EOS_TOKEN`` and
    anything else (including a missing value) to ``STOP``.

    Args:
        request_id: Id echoed back in the response.
        output: Dict read for the optional keys ``"meta_info"``,
            ``"token_ids"`` and ``"text"``.

    Returns:
        A response whose ``complete`` field carries the output ids, the
        output text and the mapped finish reason.
    """
    reasons = sglang_scheduler_pb2.GenerateComplete
    reported = output.get("meta_info", {}).get("finish_reason")

    # Plain equality checks are used so that a non-string value simply
    # falls through to STOP.
    if reported == "length":
        finish_reason = reasons.LENGTH
    elif reported == "eos_token":
        finish_reason = reasons.EOS_TOKEN
    else:
        finish_reason = reasons.STOP

    complete = sglang_scheduler_pb2.GenerateComplete(
        output_ids=output.get("token_ids", []),
        output_text=output.get("text", ""),
        finish_reason=finish_reason,
    )
    return sglang_scheduler_pb2.GenerateResponse(
        request_id=request_id, complete=complete
    )
|
||||
|
||||
async def shutdown(self):
    """Log the shutdown and delegate it to the request manager."""
    logger.info("Shutting down gRPC service")

    # This method only awaits the manager's own shutdown coroutine.
    manager = self.request_manager
    await manager.shutdown()
|
||||
|
||||
|
||||
async def serve_grpc(
    server_args: ServerArgs,
    model_info: Optional[Dict] = None,
):
    """Start the standalone gRPC server with integrated scheduler.

    Launches the scheduler process(es), builds the request manager and the
    gRPC server (with reflection enabled), then serves until SIGTERM or
    SIGINT is received.

    The scheduler processes are started first, so everything after the
    launch runs inside a ``try``/``finally``: if any later setup step fails
    (request manager creation, port binding, ``server.start()``, signal
    handler registration) the scheduler processes are still terminated
    instead of being leaked.

    Args:
        server_args: Server configuration; ``host`` and ``port`` give the
            listen address.
        model_info: Optional model description handed to the servicer. When
            ``None`` it is derived from the scheduler info and
            ``server_args``.
    """

    # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
    logger.info("Launching scheduler process(es)...")
    scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
        server_args=server_args,
    )

    servicer = None
    server = None
    server_started = False

    try:
        # Update model info from scheduler info
        if model_info is None:
            model_info = {
                "model_name": server_args.model_path,
                "max_context_length": scheduler_info.get(
                    "max_total_num_tokens", server_args.context_length or 8192
                ),
                "vocab_size": scheduler_info.get("vocab_size", 128256),
                "supports_vision": scheduler_info.get("supports_vision", False),
                "model_type": scheduler_info.get("model_type", "transformer"),
                "max_req_input_len": scheduler_info.get("max_req_input_len", 8192),
                "eos_token_ids": scheduler_info.get("eos_token_ids", []),
                "pad_token_id": scheduler_info.get("pad_token_id", 0),
                "bos_token_id": scheduler_info.get("bos_token_id", 1),
            }

        # Create request manager with the correct port args
        request_manager = GrpcRequestManager(
            server_args=server_args,
            port_args=port_args,
        )

        # Create gRPC server
        server = grpc.aio.server(
            futures.ThreadPoolExecutor(max_workers=10),
            options=[
                ("grpc.max_send_message_length", 1024 * 1024 * 256),
                ("grpc.max_receive_message_length", 1024 * 1024 * 256),
            ],
        )

        # Add service
        servicer = SGLangSchedulerServicer(
            request_manager=request_manager,
            server_args=server_args,
            model_info=model_info,
        )
        sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(
            servicer, server
        )

        # Enable reflection
        SERVICE_NAMES = (
            sglang_scheduler_pb2.DESCRIPTOR.services_by_name[
                "SglangScheduler"
            ].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)

        # Start server
        listen_addr = f"{server_args.host}:{server_args.port}"
        server.add_insecure_port(listen_addr)

        logger.info(f"Starting standalone gRPC server on {listen_addr}")

        await server.start()
        server_started = True

        # Handle shutdown signals
        loop = asyncio.get_running_loop()
        stop_event = asyncio.Event()

        def signal_handler():
            logger.info("Received shutdown signal")
            stop_event.set()

        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        await stop_event.wait()
    finally:
        logger.info("Shutting down gRPC server")
        try:
            # Only tear down what was actually brought up; setup may have
            # failed before the servicer existed or the server was started.
            if servicer is not None:
                await servicer.shutdown()
            if server_started:
                await server.stop(5.0)
        finally:
            # Runs even if the teardown above raises, so the scheduler
            # processes are never left behind.
            _terminate_scheduler_processes(scheduler_procs)


def _terminate_scheduler_processes(scheduler_procs) -> None:
    """Terminate each live scheduler process, force-killing any that linger."""
    for i, proc in enumerate(scheduler_procs):
        if proc and proc.is_alive():
            logger.info(f"Terminating scheduler process {i}...")
            proc.terminate()
            proc.join(timeout=5.0)
            if proc.is_alive():
                # terminate() was not honoured within the grace period.
                logger.warning(f"Force killing scheduler process {i}...")
                proc.kill()
                proc.join()
|
||||
|
||||
|
||||
def main():
    """Parse the command line and run the standalone gRPC server.

    The parsed options are converted into a ``ServerArgs`` instance, which
    is then handed to ``serve_grpc`` under ``asyncio.run``.
    """
    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
    mp.set_start_method("spawn", force=True)

    # (flag, add_argument keyword options), registered in this exact order
    # so the generated --help output keeps its layout.
    cli_options = [
        # Server arguments
        ("--host", dict(type=str, default="0.0.0.0", help="Host to bind to")),
        ("--port", dict(type=int, default=30000, help="gRPC server port")),
        # Model arguments
        ("--model-path", dict(type=str, required=True, help="Model path")),
        ("--tokenizer-path", dict(type=str, help="Tokenizer path")),
        ("--context-length", dict(type=int, help="Context length")),
        ("--tp-size", dict(type=int, default=1, help="Tensor parallel size")),
        ("--dp-size", dict(type=int, default=1, help="Data parallel size")),
        # Runtime arguments
        (
            "--max-running-requests",
            dict(type=int, default=2048, help="Max concurrent requests"),
        ),
        (
            "--max-total-tokens",
            dict(type=int, default=1000000, help="Max total tokens"),
        ),
        (
            "--max-prefill-tokens",
            dict(type=int, default=16384, help="Max prefill tokens"),
        ),
        (
            "--attention-backend",
            dict(type=str, default="flashinfer", help="Attention backend"),
        ),
        ("--lora-paths", dict(type=str, help="LoRA adapter paths")),
        # Logging
        ("--log-level", dict(type=str, default="INFO", help="Logging level")),
    ]

    parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # The tokenizer path falls back to the model path; LoRA paths arrive as
    # one comma-separated string.
    tokenizer_path = args.tokenizer_path or args.model_path
    lora_paths = args.lora_paths.split(",") if args.lora_paths else None

    # Convert to ServerArgs with gRPC host and port
    server_args = ServerArgs(
        model_path=args.model_path,
        tokenizer_path=tokenizer_path,
        context_length=args.context_length,
        tp_size=args.tp_size,
        dp_size=args.dp_size,
        max_running_requests=args.max_running_requests,
        max_total_tokens=args.max_total_tokens,
        max_prefill_tokens=args.max_prefill_tokens,
        attention_backend=args.attention_backend,
        lora_paths=lora_paths,
        log_level=args.log_level,
        # Override with gRPC server host and port
        host=args.host,
        port=args.port,
    )

    # Run server
    asyncio.run(serve_grpc(server_args=server_args))
|
||||
|
||||
|
||||
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user