Implement Standalone gRPC Server for SGLang Python Scheduler (#10283)

python/sglang/srt/entrypoints/grpc_request_manager.py (new file, 580 lines)
@@ -0,0 +1,580 @@
"""
gRPC Request Manager - Orchestrates request lifecycle without tokenization.
Mimics TokenizerManager's state management and ZMQ communication patterns.
"""

import asyncio
import dataclasses
import logging
import os
import signal
import sys
import threading
import time
from typing import Any, Dict, List, Optional, Union

import grpc
import zmq
import zmq.asyncio

from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchTokenIDOut,
    HealthCheckOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


class GrpcSignalHandler:
    """Minimal signal handler for gRPC server - delegates real crash handling to scheduler."""

    def __init__(self, grpc_manager):
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Handle SIGTERM by gracefully shutting down gRPC server."""
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        )
        self.grpc_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Handle SIGQUIT from failed scheduler process."""
        logger.error(
            "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
        )
        logger.info(
            "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
        )
        # Just exit cleanly - the scheduler handles crash dumps
        kill_process_tree(os.getpid(), include_parent=True)


@dataclasses.dataclass
class GrpcReqState:
    """State tracking for a gRPC request."""

    # Request identification
    request_id: str
    grpc_context: Optional[grpc.aio.ServicerContext]

    # Communication
    out_queue: asyncio.Queue
    finished: bool
    event: asyncio.Event
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # Metrics (same as TokenizerManager's ReqState)
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # Streaming state
    last_output_offset: int = 0
    stream_finished: bool = False

    # Output accumulation
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)

    # Session state
    session_id: Optional[str] = None
    is_session_request: bool = False


class GrpcRequestManager:
    """
    Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
    behaviors without tokenization.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialize the gRPC request manager."""
        self.server_args = server_args
        self.port_args = port_args

        # ZMQ Communication Setup (same pattern as TokenizerManager)
        context = zmq.asyncio.Context(2)

        # Socket for receiving outputs from scheduler
        self.recv_from_scheduler = get_zmq_socket(
            context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
        )

        # Socket for sending requests to scheduler
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
        )

        # State Management (from TokenizerManager)
        self.rid_to_state: Dict[str, GrpcReqState] = {}
        self.asyncio_tasks: set = set()
        self.gracefully_exit = False
        self.no_create_loop = False
        self.event_loop = None

        # Pause/Resume Control
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

        # Metrics
        self.request_counter = 0
        self.request_counter_lock = asyncio.Lock()
        self.last_receive_tstamp = time.time()

        # Crash dump for debugging
        self.crash_dump_request_list = []
        self.crash_dump_performed = False

        logger.info(
            f"GrpcRequestManager initialized with ZMQ IPC: "
            f"recv={port_args.detokenizer_ipc_name}, "
            f"send={port_args.scheduler_input_ipc_name}"
        )

    async def generate_request(
        self,
        obj: TokenizedGenerateReqInput,
        request_id: Optional[str] = None,
        grpc_context: Optional[grpc.aio.ServicerContext] = None,
    ) -> asyncio.Queue:
        """
        Submit a generation request to the scheduler.
        Returns a queue for streaming outputs.
        """
        # Generate request ID if not provided
        if request_id is None:
            async with self.request_counter_lock:
                request_id = f"grpc-{self.request_counter}"
                self.request_counter += 1

        obj.rid = request_id

        # TODO: support log_request

        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=grpc_context,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Track session if needed
        if hasattr(obj, "session_params") and obj.session_params:
            state.session_id = obj.session_params.session_id
            state.is_session_request = True

        # Register state
        self.rid_to_state[request_id] = state
        self.record_request_for_crash_dump(obj)

        # Send to scheduler via ZMQ
        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            # Clean up on failure
            del self.rid_to_state[request_id]
            raise RuntimeError(f"Failed to send request to scheduler: {e}")

        return state.out_queue

    async def embedding_request(
        self,
        obj: TokenizedEmbeddingReqInput,
        request_id: Optional[str] = None,
    ) -> asyncio.Future:
        """
        Submit an embedding request to the scheduler.
        Returns a future that will contain the embedding result.
        """
        # Generate request ID if not provided
        if request_id is None:
            async with self.request_counter_lock:
                request_id = f"grpc-embed-{self.request_counter}"
                self.request_counter += 1

        obj.rid = request_id

        # Create request state
        state = GrpcReqState(
            request_id=request_id,
            grpc_context=None,
            out_queue=asyncio.Queue(),
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=time.time(),
        )

        # Register state
        self.rid_to_state[request_id] = state

        # Create future for result
        future = asyncio.Future()

        # Send to scheduler
        try:
            await self._send_to_scheduler(obj)
        except Exception as e:
            del self.rid_to_state[request_id]
            future.set_exception(e)
            return future

        # Wait for result in background
        async def wait_for_result():
            try:
                # Wait for completion
                await state.event.wait()
                # Get result from queue
                result = await state.out_queue.get()
                future.set_result(result)
            except Exception as e:
                future.set_exception(e)
            finally:
                # Clean up
                if request_id in self.rid_to_state:
                    del self.rid_to_state[request_id]

        asyncio.create_task(wait_for_result())
        return future

    async def abort_request(self, request_id: str) -> bool:
        """Abort a running request."""
        if request_id not in self.rid_to_state:
            return False

        # Send abort to scheduler
        abort_req = AbortReq(rid=request_id)
        try:
            await self._send_to_scheduler(abort_req)
        except Exception as e:
            logger.error(f"Failed to send abort request: {e}")
            return False

        # Mark as finished
        state = self.rid_to_state.get(request_id)
        if state:
            state.finished = True
            state.stream_finished = True
            state.event.set()

            # Send abort notification to output queue
            await state.out_queue.put({"error": "Request aborted", "abort": True})

        return True

    async def pause_generation(self):
        """Pause generation processing."""
        async with self.is_pause_cond:
            self.is_pause = True
            logger.info("Generation paused")

    async def resume_generation(self):
        """Resume generation processing."""
        async with self.is_pause_cond:
            self.is_pause = False
            self.is_pause_cond.notify_all()
            logger.info("Generation resumed")

    async def handle_loop(self):
        """
        Main event loop - processes outputs from scheduler.
        Mimics TokenizerManager's handle_loop.
        """
        while not self.gracefully_exit:
            try:
                # Receive from scheduler
                recv_obj = await self.recv_from_scheduler.recv_pyobj()
                self.last_receive_tstamp = time.time()

                # Check for pause
                async with self.is_pause_cond:
                    while self.is_pause:
                        await self.is_pause_cond.wait()

                # Handle different output types
                if isinstance(recv_obj, BatchTokenIDOut):
                    await self._handle_batch_output(recv_obj)
                elif isinstance(recv_obj, BatchEmbeddingOut):
                    await self._handle_embedding_output(recv_obj)
                elif isinstance(recv_obj, HealthCheckOutput):
                    await self._handle_health_check_output(recv_obj)
                else:
                    logger.warning(f"Unknown output type: {type(recv_obj)}")

            except zmq.error.Again:
                # Timeout, check if we should exit
                if self.gracefully_exit:
                    break
                continue
            except Exception as e:
                logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
                if self.gracefully_exit:
                    break

    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
        """Handle batch generation output from scheduler."""
        # Process each request in the batch
        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Update metrics
            now = time.time()
            if state.first_token_time == 0.0:
                state.first_token_time = now
            state.last_time = now

            # Extract output for this request
            output_data = {
                "request_id": rid,
                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                "finished": batch_out.finished_reasons[i] is not None,
                "meta_info": {
                    "prompt_tokens": (
                        batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                    ),
                    "completion_tokens": (
                        batch_out.completion_tokens[i]
                        if batch_out.completion_tokens
                        else 0
                    ),
                    "finish_reason": (
                        str(batch_out.finished_reasons[i])
                        if batch_out.finished_reasons[i]
                        else None
                    ),
                },
            }

            # Add logprobs if available
            if batch_out.output_token_logprobs_val and i < len(
                batch_out.output_token_logprobs_val
            ):
                output_data["logprobs"] = {
                    "tokens": batch_out.output_token_logprobs_val[i],
                    "top_logprobs": (
                        batch_out.output_top_logprobs_val[i]
                        if batch_out.output_top_logprobs_val
                        and i < len(batch_out.output_top_logprobs_val)
                        else None
                    ),
                }

            # Update state
            if output_data["text"]:
                state.text += output_data["text"][state.last_output_offset :]
                state.last_output_offset = len(output_data["text"])

            if output_data["token_ids"]:
                state.output_ids.extend(output_data["token_ids"])

            # Send to output queue
            await state.out_queue.put(output_data)

            # Handle completion
            if output_data["finished"]:
                state.finished = True
                state.finished_time = now
                state.stream_finished = True
                state.event.set()

                # Remove from tracking after a delay.
                # Bind rid as a default argument so each cleanup task removes
                # its own request instead of whatever rid the loop variable
                # holds when the task finally runs.
                async def cleanup(rid=rid):
                    await asyncio.sleep(5.0)
                    if rid in self.rid_to_state:
                        del self.rid_to_state[rid]

                asyncio.create_task(cleanup())

    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
        """Handle batch embedding output from scheduler."""
        for i, rid in enumerate(batch_out.rids):
            if rid not in self.rid_to_state:
                continue

            state = self.rid_to_state[rid]

            # Create result
            result = {
                "request_id": rid,
                "embedding": batch_out.embeddings[i],
                "prompt_tokens": (
                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                ),
                "finish_reason": (
                    batch_out.finish_reason[i] if batch_out.finish_reason else None
                ),
            }

            # Send result
            await state.out_queue.put(result)

            # Mark as finished
            state.finished = True
            state.finished_time = time.time()
            state.event.set()

    async def _handle_health_check_output(self, health_out: HealthCheckOutput):
        """Handle health check output from scheduler."""
        rid = health_out.rid

        if rid not in self.rid_to_state:
            logger.warning(f"Health check output for unknown request: {rid}")
            return

        state = self.rid_to_state[rid]

        # Create health check result
        result = {
            "request_id": rid,
            "healthy": True,  # If we got a response, scheduler is healthy
            "output_text": (
                health_out.output_str if hasattr(health_out, "output_str") else ""
            ),
            "finish_reason": (
                health_out.finish_reason
                if hasattr(health_out, "finish_reason")
                else "stop"
            ),
        }

        # Send result
        await state.out_queue.put(result)

        # Mark as finished
        state.finished = True
        state.finished_time = time.time()
        state.event.set()

    async def _send_to_scheduler(self, obj):
        """Send an object to the scheduler via ZMQ."""
        try:
            # The zmq.asyncio socket returns an awaitable; await it so send
            # failures propagate to the caller instead of being dropped.
            await self.send_to_scheduler.send_pyobj(obj)
        except Exception as e:
            logger.error(f"Failed to send to scheduler: {e}")
            raise

    def record_request_for_crash_dump(self, obj):
        """Record request for potential crash dump."""
        if len(self.crash_dump_request_list) < 100:
            self.crash_dump_request_list.append(
                {
                    "time": time.time(),
                    "request_id": getattr(obj, "rid", "unknown"),
                    "type": type(obj).__name__,
                }
            )

    async def shutdown(self):
        """Gracefully shutdown the request manager."""
        logger.info("Shutting down GrpcRequestManager")
        self.gracefully_exit = True

        # Cancel all pending requests
        for rid, state in self.rid_to_state.items():
            if not state.finished:
                await state.out_queue.put(
                    {"error": "Server shutting down", "shutdown": True}
                )
                state.finished = True
                state.event.set()

        # Wait for tasks to complete
        if self.asyncio_tasks:
            await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

        # Close ZMQ sockets
        self.recv_from_scheduler.close()
        self.send_to_scheduler.close()

        logger.info("GrpcRequestManager shutdown complete")

    def get_server_info(self) -> Dict[str, Any]:
        """Get server information for health checks."""
        return {
            "active_requests": len(self.rid_to_state),
            "paused": self.is_pause,
            "last_receive_time": self.last_receive_tstamp,
        }

    def auto_create_handle_loop(self):
        """Automatically create and start the handle_loop task, matching TokenizerManager pattern."""
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        self.event_loop = loop

        # We cannot add signal handler when the grpc manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = GrpcSignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the grpc request manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "grpc request manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        """Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern."""
        while not self.gracefully_exit:
            await asyncio.sleep(1.0)


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print its exception.
    We add another wrapper to log the exception before exiting.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {traceback}")
        if hasattr(func, "__self__") and isinstance(func.__self__, GrpcRequestManager):
            func.__self__.dump_requests_before_crash()
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
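
Usage note (not part of this commit): a minimal sketch of how a caller might consume the queue returned by GrpcRequestManager.generate_request. The construction of tokenized_req is elided; note that the "text" field of each queued output is the cumulative decoded text, as implied by the offset handling in _handle_batch_output above.

import asyncio

from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager


async def collect_text(manager: GrpcRequestManager, tokenized_req) -> str:
    # Starts handle_loop and the sigterm watchdog if not already running.
    manager.auto_create_handle_loop()
    queue = await manager.generate_request(tokenized_req)
    text = ""
    while True:
        output = await queue.get()
        if "error" in output:
            raise RuntimeError(output["error"])
        # "text" is cumulative, so overwrite rather than append.
        text = output.get("text", text)
        if output.get("finished", False):
            return text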

python/sglang/srt/entrypoints/grpc_server.py (new file, 680 lines)
@@ -0,0 +1,680 @@
"""
Standalone gRPC Server for SGLang - Fully separated from HTTP server.
Uses GrpcRequestManager for orchestration without tokenization.
"""

import argparse
import asyncio
import logging
import multiprocessing as mp
import os
import signal
import time
from concurrent import futures
from typing import AsyncIterator, Dict, Optional, Tuple

import grpc
from grpc_reflection.v1alpha import reflection

from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.io_struct import (
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))


def _launch_scheduler_process_only(
    server_args: ServerArgs,
    port_args: Optional[PortArgs] = None,
) -> Tuple[Dict, PortArgs, list]:
    """
    Launch only the scheduler process(es) without tokenizer/detokenizer.
    Returns scheduler info, port args, and list of scheduler processes.
    """
    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()

    # Allocate ports for inter-process communications
    if port_args is None:
        port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # Prepare model and tokenizer paths
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    scheduler_procs = []
    if server_args.dp_size == 1:
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        scheduler_pipe_readers = []

        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
        tp_rank_range = range(
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
        )

        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
        pp_rank_range = range(
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
        )

        for pp_rank in pp_rank_range:
            for tp_rank in tp_rank_range:
                reader, writer = mp.Pipe(duplex=False)
                gpu_id = (
                    server_args.base_gpu_id
                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                )
                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
                proc = mp.Process(
                    target=run_scheduler_process,
                    args=(
                        server_args,
                        port_args,
                        gpu_id,
                        tp_rank,
                        moe_ep_rank,
                        pp_rank,
                        None,
                        writer,
                        None,
                    ),
                )

                with memory_saver_adapter.configure_subprocess():
                    proc.start()
                scheduler_procs.append(proc)
                scheduler_pipe_readers.append(reader)
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()
        scheduler_procs.append(proc)

    # TODO(CatherineSue): handle cases for multi-node

    # Wait for all scheduler processes to be ready
    scheduler_infos = []
    for i, reader in enumerate(scheduler_pipe_readers):
        try:
            data = reader.recv()
        except EOFError:
            logger.error(
                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
            )
            scheduler_procs[i].join()
            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
            raise RuntimeError(f"Failed to initialize scheduler rank {i}")

        if data.get("status") != "ready":
            raise RuntimeError(
                f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
            )
        scheduler_infos.append(data)

    logger.info(
        f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
    )

    # Return the first scheduler's info (they should all be the same)
    return scheduler_infos[0], port_args, scheduler_procs


class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
    """
    Standalone gRPC service implementation using GrpcRequestManager.
    Fully separated from HTTP server with its own process and no shared globals.
    """

    def __init__(
        self,
        request_manager: GrpcRequestManager,
        server_args: ServerArgs,
        model_info: Dict,
    ):
        """Initialize the standalone gRPC service."""
        self.request_manager = request_manager
        self.server_args = server_args
        self.model_info = model_info
        self.start_time = time.time()

        # Start the request manager's event loop using auto_create_handle_loop
        self.request_manager.auto_create_handle_loop()

        logger.info("Standalone gRPC scheduler service initialized")

    async def Generate(
        self,
        request: sglang_scheduler_pb2.GenerateRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
        """Handle generation requests with streaming responses."""
        logger.info(f"Generation request: {request.request_id}")

        try:
            # Convert gRPC request to internal format
            tokenized_req = self._convert_generate_request(request)

            # Submit to request manager
            output_queue = await self.request_manager.generate_request(
                obj=tokenized_req,
                request_id=request.request_id,
                grpc_context=context,
            )

            # Stream outputs
            while True:
                try:
                    # Get output with timeout
                    output = await asyncio.wait_for(output_queue.get(), timeout=4)

                    # Check for errors
                    if "error" in output:
                        yield sglang_scheduler_pb2.GenerateResponse(
                            request_id=request.request_id,
                            error=sglang_scheduler_pb2.GenerateError(
                                message=output["error"],
                                http_status_code=(
                                    "500" if "abort" not in output else "499"
                                ),
                            ),
                        )
                        break

                    # Check if finished
                    if output.get("finished", False):
                        # Send completion
                        yield self._create_completion_response(
                            request.request_id, output
                        )
                        break
                    else:
                        # Send chunk
                        yield self._create_chunk_response(request.request_id, output)

                except asyncio.TimeoutError:
                    # Check if context is still active
                    if context.cancelled():
                        # Abort the request
                        await self.request_manager.abort_request(request.request_id)
                        break
                    continue

        except Exception as e:
            logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
            yield sglang_scheduler_pb2.GenerateResponse(
                request_id=request.request_id,
                error=sglang_scheduler_pb2.GenerateError(
                    message=str(e),
                    http_status_code="500",
                    details=get_exception_traceback(),
                ),
            )

    async def Embed(
        self,
        request: sglang_scheduler_pb2.EmbedRequest,
        context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.EmbedResponse:
        """Handle embedding requests."""
        logger.info(f"Embedding request: {request.request_id}")

        try:
            # Convert request
            tokenized_req = self._convert_embed_request(request)

            # Submit to request manager
            future = await self.request_manager.embedding_request(
                obj=tokenized_req,
                request_id=request.request_id,
            )

            # Wait for result
            result = await future

            # Create response
            return sglang_scheduler_pb2.EmbedResponse(
                request_id=request.request_id,
                complete=sglang_scheduler_pb2.EmbedComplete(
                    embedding=result["embedding"],
                    prompt_tokens=result.get("prompt_tokens", 0),
                    cached_tokens=0,
                    embedding_dim=len(result["embedding"]),
                    generation_time=time.time() - self.start_time,
                ),
            )

        except Exception as e:
            logger.error(f"Embed failed: {e}\n{get_exception_traceback()}")
            return sglang_scheduler_pb2.EmbedResponse(
                request_id=request.request_id,
                error=sglang_scheduler_pb2.EmbedError(
                    message=str(e),
                    code="INTERNAL_ERROR",
                    details=get_exception_traceback(),
                ),
            )

    async def HealthCheck(
        self,
        request: sglang_scheduler_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.HealthCheckResponse:
        """Health check by generating from client input."""
        try:
            # Check if request manager is shutting down
            if self.request_manager.gracefully_exit:
                return sglang_scheduler_pb2.HealthCheckResponse(
                    healthy=False, message="Server shutting down"
                )

            # Extract tokenized input from request
            if not request.HasField("tokenized"):
                return sglang_scheduler_pb2.HealthCheckResponse(
                    healthy=False, message="Tokenized input required for health check"
                )

            input_text = request.tokenized.original_text
            input_ids = list(request.tokenized.input_ids)

            # Create health check request
            rid = f"HEALTH_CHECK_GRPC_{time.time()}"

            health_request = TokenizedGenerateReqInput(
                rid=rid,
                input_text=input_text,
                input_ids=input_ids,
                sampling_params=SGLSamplingParams(max_new_tokens=1, temperature=0.0),
                stream=False,
                mm_inputs=None,
                return_logprob=False,
                logprob_start_len=-1,
                top_logprobs_num=0,
                token_ids_logprob=None,
            )

            logger.info("Sending health check request to request manager...")

            # Submit and wait for response
            output_queue = await self.request_manager.generate_request(
                health_request, request_id=rid
            )

            try:
                # Wait for response with configurable timeout
                response = await asyncio.wait_for(
                    output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
                )

                # Clean up
                if rid in self.request_manager.rid_to_state:
                    del self.request_manager.rid_to_state[rid]

                return sglang_scheduler_pb2.HealthCheckResponse(
                    healthy=True, message="Health check passed"
                )

            except asyncio.TimeoutError:
                # Clean up on timeout
                if rid in self.request_manager.rid_to_state:
                    del self.request_manager.rid_to_state[rid]

                return sglang_scheduler_pb2.HealthCheckResponse(
                    healthy=False, message="Health check timeout"
                )

        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return sglang_scheduler_pb2.HealthCheckResponse(
                healthy=False, message=f"Health check error: {str(e)}"
            )

    async def Abort(
        self,
        request: sglang_scheduler_pb2.AbortRequest,
        context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.AbortResponse:
        """Abort an ongoing request."""
        logger.info(f"Aborting request: {request.request_id}")

        try:
            success = await self.request_manager.abort_request(request.request_id)

            return sglang_scheduler_pb2.AbortResponse(
                success=success,
                message=f"Request {request.request_id} {'aborted' if success else 'not found'}",
            )
        except Exception as e:
            logger.error(f"Abort failed: {e}")
            return sglang_scheduler_pb2.AbortResponse(
                success=False,
                message=str(e),
            )

    # Helper methods for request/response conversion

    def _convert_generate_request(
        self, grpc_req: sglang_scheduler_pb2.GenerateRequest
    ) -> TokenizedGenerateReqInput:
        """Convert gRPC GenerateRequest to internal format."""

        # Extract tokenized input
        if not grpc_req.HasField("tokenized"):
            raise ValueError("Tokenized input must be provided")

        input_text = grpc_req.tokenized.original_text
        input_ids = list(grpc_req.tokenized.input_ids)

        # Convert sampling params
        sampling_params = self._convert_sampling_params(grpc_req.sampling_params)

        # Create request
        return TokenizedGenerateReqInput(
            rid=grpc_req.request_id,
            input_text=input_text,
            input_ids=input_ids,
            mm_inputs=None,  # TODO: implement mm support
            sampling_params=sampling_params,
            return_logprob=grpc_req.return_logprob,
            logprob_start_len=grpc_req.logprob_start_len or -1,
            top_logprobs_num=grpc_req.top_logprobs_num or 0,
            stream=True,  # Always stream for gRPC
            lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
            token_ids_logprob=(
                list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
            ),
        )

    def _convert_embed_request(
        self, grpc_req: sglang_scheduler_pb2.EmbedRequest
    ) -> TokenizedEmbeddingReqInput:
        """Convert gRPC EmbedRequest to internal format."""

        # Extract tokenized input
        if not grpc_req.HasField("tokenized"):
            raise ValueError("Tokenized input must be provided")

        input_text = grpc_req.tokenized.original_text
        input_ids = list(grpc_req.tokenized.input_ids)

        return TokenizedEmbeddingReqInput(
            rid=grpc_req.request_id,
            input_text=input_text,
            input_ids=input_ids,
        )

    def _convert_sampling_params(
        self, grpc_params: sglang_scheduler_pb2.SamplingParams
    ) -> SGLSamplingParams:
        """Convert gRPC SamplingParams to internal format."""

        # Handle constraint types
        regex = None
        json_schema = None
        ebnf_grammar = None

        if grpc_params.HasField("regex"):
            regex = grpc_params.regex
        elif grpc_params.HasField("json_schema"):
            json_schema = grpc_params.json_schema
        elif grpc_params.HasField("ebnf_grammar"):
            ebnf_grammar = grpc_params.ebnf_grammar

        return SGLSamplingParams(
            temperature=grpc_params.temperature or 1.0,
            top_p=grpc_params.top_p or 1.0,
            top_k=grpc_params.top_k or -1,
            min_p=grpc_params.min_p or 0.0,
            frequency_penalty=grpc_params.frequency_penalty or 0.0,
            presence_penalty=grpc_params.presence_penalty or 0.0,
            repetition_penalty=grpc_params.repetition_penalty or 1.0,
            max_new_tokens=grpc_params.max_new_tokens or 128,
            min_new_tokens=grpc_params.min_new_tokens or 0,
            stop=list(grpc_params.stop) if grpc_params.stop else None,
            stop_token_ids=(
                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
            ),
            skip_special_tokens=grpc_params.skip_special_tokens,
            spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
            regex=regex,
            json_schema=json_schema,
            ebnf=ebnf_grammar,
            n=grpc_params.n or 1,
            ignore_eos=grpc_params.ignore_eos,
        )

    def _create_chunk_response(
        self, request_id: str, output: Dict
    ) -> sglang_scheduler_pb2.GenerateResponse:
        """Create a streaming chunk response."""
        return sglang_scheduler_pb2.GenerateResponse(
            request_id=request_id,
            chunk=sglang_scheduler_pb2.GenerateStreamChunk(
                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
                text=output.get("text", ""),
                prompt_tokens=0,
                completion_tokens=len(output.get("token_ids", [])),
                cached_tokens=0,
                generation_time=time.time() - self.start_time,
                queue_time=0.0,
            ),
        )

    def _create_completion_response(
        self, request_id: str, output: Dict
    ) -> sglang_scheduler_pb2.GenerateResponse:
        """Create a completion response."""

        # Determine finish reason
        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
        meta_info = output.get("meta_info", {})
        if meta_info.get("finish_reason") == "length":
            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
        elif meta_info.get("finish_reason") == "eos_token":
            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN

        return sglang_scheduler_pb2.GenerateResponse(
            request_id=request_id,
            complete=sglang_scheduler_pb2.GenerateComplete(
                output_ids=output.get("token_ids", []),
                output_text=output.get("text", ""),
                finish_reason=finish_reason,
            ),
        )

    async def shutdown(self):
        """Shutdown the service."""
        logger.info("Shutting down gRPC service")

        # Shutdown request manager (handles its own tasks)
        await self.request_manager.shutdown()


async def serve_grpc(
    server_args: ServerArgs,
    model_info: Optional[Dict] = None,
):
    """Start the standalone gRPC server with integrated scheduler."""

    # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
    logger.info("Launching scheduler process(es)...")
    scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
        server_args=server_args,
    )

    # Update model info from scheduler info
    if model_info is None:
        model_info = {
            "model_name": server_args.model_path,
            "max_context_length": scheduler_info.get(
                "max_total_num_tokens", server_args.context_length or 8192
            ),
            "vocab_size": scheduler_info.get("vocab_size", 128256),
            "supports_vision": scheduler_info.get("supports_vision", False),
            "model_type": scheduler_info.get("model_type", "transformer"),
            "max_req_input_len": scheduler_info.get("max_req_input_len", 8192),
            "eos_token_ids": scheduler_info.get("eos_token_ids", []),
            "pad_token_id": scheduler_info.get("pad_token_id", 0),
            "bos_token_id": scheduler_info.get("bos_token_id", 1),
        }

    # Create request manager with the correct port args
    request_manager = GrpcRequestManager(
        server_args=server_args,
        port_args=port_args,
    )

    # Create gRPC server
    server = grpc.aio.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
            ("grpc.max_send_message_length", 1024 * 1024 * 256),
            ("grpc.max_receive_message_length", 1024 * 1024 * 256),
        ],
    )

    # Add service
    servicer = SGLangSchedulerServicer(
        request_manager=request_manager,
        server_args=server_args,
        model_info=model_info,
    )
    sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)

    # Enable reflection
    SERVICE_NAMES = (
        sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(SERVICE_NAMES, server)

    # Start server
    listen_addr = f"{server_args.host}:{server_args.port}"
    server.add_insecure_port(listen_addr)

    logger.info(f"Starting standalone gRPC server on {listen_addr}")

    await server.start()

    # Handle shutdown signals
    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()

    def signal_handler():
        logger.info("Received shutdown signal")
        stop_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    try:
        await stop_event.wait()
    finally:
        logger.info("Shutting down gRPC server")
        await servicer.shutdown()
        await server.stop(5.0)

        # Terminate scheduler processes
        for i, proc in enumerate(scheduler_procs):
            if proc and proc.is_alive():
                logger.info(f"Terminating scheduler process {i}...")
                proc.terminate()
                proc.join(timeout=5.0)
                if proc.is_alive():
                    logger.warning(f"Force killing scheduler process {i}...")
                    proc.kill()
                    proc.join()


def main():
    """Main entry point for standalone gRPC server."""
    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
    mp.set_start_method("spawn", force=True)

    parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")

    # Server arguments
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=30000, help="gRPC server port")

    # Model arguments
    parser.add_argument("--model-path", type=str, required=True, help="Model path")
    parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
    parser.add_argument("--context-length", type=int, help="Context length")
    parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")

    # Runtime arguments
    parser.add_argument(
        "--max-running-requests", type=int, default=2048, help="Max concurrent requests"
    )
    parser.add_argument(
        "--max-total-tokens", type=int, default=1000000, help="Max total tokens"
    )
    parser.add_argument(
        "--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
    )
    parser.add_argument(
        "--attention-backend", type=str, default="flashinfer", help="Attention backend"
    )
    parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")

    # Logging
    parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")

    args = parser.parse_args()

    # Convert to ServerArgs with gRPC host and port
    server_args = ServerArgs(
        model_path=args.model_path,
        tokenizer_path=args.tokenizer_path or args.model_path,
        context_length=args.context_length,
        tp_size=args.tp_size,
        dp_size=args.dp_size,
        max_running_requests=args.max_running_requests,
        max_total_tokens=args.max_total_tokens,
        max_prefill_tokens=args.max_prefill_tokens,
        attention_backend=args.attention_backend,
        lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
        log_level=args.log_level,
        # Override with gRPC server host and port
        host=args.host,
        port=args.port,
    )

    # Run server
    asyncio.run(
        serve_grpc(
            server_args=server_args,
        )
    )


if __name__ == "__main__":
    main()
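
Usage note (not part of this commit): a minimal async client sketch against this server, assuming the checked-in stubs and a server started with something like "python -m sglang.srt.entrypoints.grpc_server --model-path <model> --port 30000". The input_ids below are placeholders; the service requires pre-tokenized input, so a real client tokenizes first.

import asyncio

import grpc

from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc


async def demo():
    async with grpc.aio.insecure_channel("localhost:30000") as channel:
        stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel)
        request = sglang_scheduler_pb2.GenerateRequest(
            request_id="demo-1",
            tokenized=sglang_scheduler_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[9906],  # placeholder token ids
            ),
            sampling_params=sglang_scheduler_pb2.SamplingParams(
                temperature=0.7,
                max_new_tokens=32,
            ),
        )
        # Generate is server-streaming: iterate chunks until complete/error.
        async for response in stub.Generate(request):
            if response.HasField("chunk"):
                print(response.chunk.text, end="", flush=True)
            elif response.HasField("complete"):
                print()
            elif response.HasField("error"):
                raise RuntimeError(response.error.message)


asyncio.run(demo())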

python/sglang/srt/grpc/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# SGLang gRPC module

python/sglang/srt/grpc/sglang_scheduler.proto (new file, 389 lines)
@@ -0,0 +1,389 @@
syntax = "proto3";

package sglang.grpc.scheduler;

import "google/protobuf/timestamp.proto";
import "google/protobuf/struct.proto";

// Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler
service SglangScheduler {
  // Submit a generation request (supports streaming)
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);

  // Submit an embedding request
  rpc Embed(EmbedRequest) returns (EmbedResponse);

  // Health check and metrics
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);

  // Abort a running request
  rpc Abort(AbortRequest) returns (AbortResponse);
}

// =====================
// Common Types
// =====================

// Sampling parameters matching SGLang's SamplingParams
message SamplingParams {
  float temperature = 1;
  float top_p = 2;
  int32 top_k = 3;
  float min_p = 4;
  float frequency_penalty = 5;
  float presence_penalty = 6;
  float repetition_penalty = 7;

  int32 max_new_tokens = 8;
  repeated string stop = 9;
  repeated int32 stop_token_ids = 10;
  bool skip_special_tokens = 11;
  bool spaces_between_special_tokens = 12;

  // Structured generation
  oneof constraint {
    string regex = 13;
    string json_schema = 14;
    string ebnf_grammar = 15;
  }

  // LoRA adapter
  string lora_path = 16;

  // Speculative decoding
  int32 n = 17;  // Number of samples

  // Token healing
  bool token_healing = 18;

  // Additional parameters
  int32 min_new_tokens = 19;
  bool ignore_eos = 20;
  bool no_stop_trim = 21;
  int32 stream_interval = 22;
  map<string, float> logit_bias = 23;
  string structural_tag = 24;

  // Custom parameters for extensibility
  google.protobuf.Struct custom_params = 25;
}

// Disaggregated serving parameters
message DisaggregatedParams {
  string bootstrap_host = 1;
  int32 bootstrap_port = 2;
  int32 bootstrap_room = 3;
}

// =====================
// Generate Request
// =====================

message GenerateRequest {
  string request_id = 1;

  // Input must be tokenized (no raw text)
  TokenizedInput tokenized = 2;

  // Multimodal inputs
  MultimodalInputs mm_inputs = 3;

  // Generation parameters
  SamplingParams sampling_params = 4;

  // Return options
  bool return_logprob = 5;
  int32 logprob_start_len = 6;
  int32 top_logprobs_num = 7;
  repeated int32 token_ids_logprob = 8;
  bool return_hidden_states = 9;

  // For disaggregated serving
  DisaggregatedParams disaggregated_params = 10;

  // Custom logit processor (serialized)
  string custom_logit_processor = 11;

  // Request metadata
  google.protobuf.Timestamp timestamp = 12;
  bool log_metrics = 13;

  // Input embeddings (alternative to text/tokens)
  repeated float input_embeds = 14;

  // LoRA adapter ID (if pre-loaded)
  string lora_id = 15;

  // Data parallel routing
  int32 data_parallel_rank = 16;

  // For load balancing
  int32 dp_balance_id = 17;
}

message TokenizedInput {
  string original_text = 1;  // For reference
  repeated int32 input_ids = 2;
}

message MultimodalInputs {
  // Simplified multimodal handling - actual data processed by tokenizer
  repeated string image_urls = 1;
  repeated string video_urls = 2;
  repeated string audio_urls = 3;

  // Pre-processed multimodal features (if available)
  google.protobuf.Struct processed_features = 4;

  // Raw data for direct processing
  repeated bytes image_data = 5;
  repeated bytes video_data = 6;
  repeated bytes audio_data = 7;

  // Modality metadata
  repeated string modalities = 8;
}

// =====================
// Generate Response
// =====================

message GenerateResponse {
  string request_id = 1;

  // Response type
  oneof response {
    GenerateStreamChunk chunk = 2;
    GenerateComplete complete = 3;
    GenerateError error = 4;
  }
}

message GenerateStreamChunk {
  // Generated token
  int32 token_id = 1;
  string text = 2;

  // Cumulative counts
  int32 prompt_tokens = 3;
  int32 completion_tokens = 4;
  int32 cached_tokens = 5;

  // Logprobs (if requested)
  LogProbs logprobs = 6;

  // Hidden states (if requested)
  repeated float hidden_states = 7;

  // Metadata
  float generation_time = 8;  // Time to generate this token
  int32 queue_time = 9;  // Time spent in queue
}

message GenerateComplete {
  // Final output
  repeated int32 output_ids = 1;
  string output_text = 2;

  // Finish reason
  enum FinishReason {
    // The model generated a stop sequence.
    STOP = 0;
    // The model reached the maximum generation length.
    LENGTH = 1;
    // The model generated an end-of-sequence (EOS) token.
    EOS_TOKEN = 2;
    // The model generated a user-provided stop string.
    STOP_STR = 3;
    // The request was aborted by the user or system.
    ABORT = 4;
  }
  FinishReason finish_reason = 3;

  // All logprobs if requested
  repeated LogProbs all_logprobs = 11;

  // All hidden states if requested
  repeated HiddenStates all_hidden_states = 12;
}

message GenerateError {
  string message = 1;
  string http_status_code = 2;
  string details = 3;
}

message LogProbs {
  repeated float token_logprobs = 1;
  repeated int32 token_ids = 2;

  // Top logprobs at each position
  repeated TopLogProbs top_logprobs = 3;

  // Decoded text for tokens
  repeated string token_texts = 4;
}

message TopLogProbs {
  repeated float values = 1;
  repeated int32 token_ids = 2;
  repeated string token_texts = 3;
}

message HiddenStates {
  repeated float values = 1;
  int32 layer = 2;
  int32 position = 3;
}

// =====================
// Embedding Request
// =====================

message EmbedRequest {
  string request_id = 1;

  // Input must be tokenized (no raw text)
  TokenizedInput tokenized = 2;

  // Multimodal inputs
  MultimodalInputs mm_inputs = 4;

  // Dummy sampling params for compatibility
  // EmbedRequest doesn't use sampling_params
  SamplingParams sampling_params = 5;

  bool log_metrics = 6;

  // Token type IDs for models that require them
  repeated int32 token_type_ids = 7;

  // Data parallel routing
  int32 data_parallel_rank = 8;

  // For cross-encoder requests
  bool is_cross_encoder = 9;
  repeated string texts = 10;  // For cross-encoder batch
}

message EmbedResponse {
  string request_id = 1;

  oneof response {
    EmbedComplete complete = 2;
    EmbedError error = 3;
  }
}

message EmbedComplete {
  repeated float embedding = 1;
  int32 prompt_tokens = 2;
  int32 cached_tokens = 3;

  // Additional metadata
  int32 embedding_dim = 4;
  float generation_time = 5;

  // For batch embeddings
  repeated Embedding batch_embeddings = 6;
}

message Embedding {
  repeated float values = 1;
  int32 index = 2;
}

message EmbedError {
  string message = 1;
  string code = 2;
  string details = 3;
}

// =====================
// Management Operations
// =====================

message HealthCheckRequest {
  // Input for health test generation (must be tokenized)
  TokenizedInput tokenized = 1;
}

message HealthCheckResponse {
  bool healthy = 1;
  string message = 2;
}

message AbortRequest {
  string request_id = 1;
  string reason = 2;
}

message AbortResponse {
  bool success = 1;
  string message = 2;
}

// =====================
// Additional Operations (Future)
// =====================

// Load LoRA adapter
message LoadLoRARequest {
  string adapter_id = 1;
  string adapter_path = 2;
  int32 rank = 3;
}

message LoadLoRAResponse {
  bool success = 1;
  string adapter_id = 2;
  string message = 3;
}

// Unload LoRA adapter
message UnloadLoRARequest {
  string adapter_id = 1;
}

message UnloadLoRAResponse {
  bool success = 1;
  string message = 2;
}

// Update weights
message UpdateWeightsRequest {
  oneof source {
    string disk_path = 1;
    bytes tensor_data = 2;
    string remote_url = 3;
  }
  string weight_name = 4;
}

message UpdateWeightsResponse {
  bool success = 1;
  string message = 2;
}

// Get internal state for debugging
message GetInternalStateRequest {
  repeated string state_keys = 1;
}

message GetInternalStateResponse {
  google.protobuf.Struct state = 1;
}

// Set internal state for testing
message SetInternalStateRequest {
  google.protobuf.Struct state = 1;
}

message SetInternalStateResponse {
  bool success = 1;
  string message = 2;
}
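
Usage note (not part of this commit): a sketch of regenerating the checked-in Python stubs from this .proto with grpcio-tools, assuming a recent grpcio-tools (for --pyi_out) and that it is run from the python/ directory. grpc_tools.protoc.main() automatically adds the include path for the google.protobuf well-known types imported above.

from grpc_tools import protoc

protoc.main([
    "grpc_tools.protoc",
    "-Isglang/srt/grpc",
    "--python_out=sglang/srt/grpc",
    "--pyi_out=sglang/srt/grpc",
    "--grpc_python_out=sglang/srt/grpc",
    "sglang/srt/grpc/sglang_scheduler.proto",
])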

python/sglang/srt/grpc/sglang_scheduler_pb2.py (new file, 106 lines)
File diff suppressed because one or more lines are too long

python/sglang/srt/grpc/sglang_scheduler_pb2.pyi (new file, 427 lines)
@@ -0,0 +1,427 @@
import datetime

from google.protobuf import timestamp_pb2 as _timestamp_pb2
from google.protobuf import struct_pb2 as _struct_pb2
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union

DESCRIPTOR: _descriptor.FileDescriptor

class SamplingParams(_message.Message):
    __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params")
    class LogitBiasEntry(_message.Message):
        __slots__ = ("key", "value")
        KEY_FIELD_NUMBER: _ClassVar[int]
        VALUE_FIELD_NUMBER: _ClassVar[int]
        key: str
        value: float
        def __init__(self, key: _Optional[str] = ..., value: _Optional[float] = ...) -> None: ...
    TEMPERATURE_FIELD_NUMBER: _ClassVar[int]
    TOP_P_FIELD_NUMBER: _ClassVar[int]
    TOP_K_FIELD_NUMBER: _ClassVar[int]
    MIN_P_FIELD_NUMBER: _ClassVar[int]
    FREQUENCY_PENALTY_FIELD_NUMBER: _ClassVar[int]
    PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
    REPETITION_PENALTY_FIELD_NUMBER: _ClassVar[int]
    MAX_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
    STOP_FIELD_NUMBER: _ClassVar[int]
    STOP_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
    SKIP_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int]
    SPACES_BETWEEN_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int]
    REGEX_FIELD_NUMBER: _ClassVar[int]
    JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
    EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
    LORA_PATH_FIELD_NUMBER: _ClassVar[int]
    N_FIELD_NUMBER: _ClassVar[int]
    TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
    MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
    IGNORE_EOS_FIELD_NUMBER: _ClassVar[int]
    NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
    STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
    LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
    STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
    CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
    temperature: float
    top_p: float
    top_k: int
    min_p: float
    frequency_penalty: float
    presence_penalty: float
    repetition_penalty: float
    max_new_tokens: int
    stop: _containers.RepeatedScalarFieldContainer[str]
    stop_token_ids: _containers.RepeatedScalarFieldContainer[int]
    skip_special_tokens: bool
    spaces_between_special_tokens: bool
    regex: str
    json_schema: str
    ebnf_grammar: str
    lora_path: str
    n: int
    token_healing: bool
    min_new_tokens: int
    ignore_eos: bool
    no_stop_trim: bool
    stream_interval: int
    logit_bias: _containers.ScalarMap[str, float]
    structural_tag: str
    custom_params: _struct_pb2.Struct
    def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...

class DisaggregatedParams(_message.Message):
    __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
    BOOTSTRAP_HOST_FIELD_NUMBER: _ClassVar[int]
    BOOTSTRAP_PORT_FIELD_NUMBER: _ClassVar[int]
    BOOTSTRAP_ROOM_FIELD_NUMBER: _ClassVar[int]
    bootstrap_host: str
    bootstrap_port: int
    bootstrap_room: int
    def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...

class GenerateRequest(_message.Message):
    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id")
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    TOKENIZED_FIELD_NUMBER: _ClassVar[int]
    MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
    SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
    RETURN_LOGPROB_FIELD_NUMBER: _ClassVar[int]
    LOGPROB_START_LEN_FIELD_NUMBER: _ClassVar[int]
    TOP_LOGPROBS_NUM_FIELD_NUMBER: _ClassVar[int]
    TOKEN_IDS_LOGPROB_FIELD_NUMBER: _ClassVar[int]
    RETURN_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
    DISAGGREGATED_PARAMS_FIELD_NUMBER: _ClassVar[int]
    CUSTOM_LOGIT_PROCESSOR_FIELD_NUMBER: _ClassVar[int]
    TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
    LOG_METRICS_FIELD_NUMBER: _ClassVar[int]
    INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
    LORA_ID_FIELD_NUMBER: _ClassVar[int]
    DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
    DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
    request_id: str
    tokenized: TokenizedInput
    mm_inputs: MultimodalInputs
    sampling_params: SamplingParams
    return_logprob: bool
    logprob_start_len: int
    top_logprobs_num: int
    token_ids_logprob: _containers.RepeatedScalarFieldContainer[int]
    return_hidden_states: bool
    disaggregated_params: DisaggregatedParams
    custom_logit_processor: str
    timestamp: _timestamp_pb2.Timestamp
    log_metrics: bool
    input_embeds: _containers.RepeatedScalarFieldContainer[float]
    lora_id: str
    data_parallel_rank: int
    dp_balance_id: int
    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ...
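
As a quick illustration of how these stubs compose, a GenerateRequest can be assembled from the TokenizedInput and SamplingParams messages in this file roughly as follows; the token IDs and sampling values are made-up placeholders, since tokenization happens upstream of the scheduler:

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

request = pb2.GenerateRequest(
    request_id="req-123",
    tokenized=pb2.TokenizedInput(
        original_text="Hello, world",
        input_ids=[9906, 11, 1917],  # placeholder token IDs
    ),
    sampling_params=pb2.SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=128,
        stop=["</s>"],
    ),
    return_logprob=False,
)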

class TokenizedInput(_message.Message):
    __slots__ = ("original_text", "input_ids")
    ORIGINAL_TEXT_FIELD_NUMBER: _ClassVar[int]
    INPUT_IDS_FIELD_NUMBER: _ClassVar[int]
    original_text: str
    input_ids: _containers.RepeatedScalarFieldContainer[int]
    def __init__(self, original_text: _Optional[str] = ..., input_ids: _Optional[_Iterable[int]] = ...) -> None: ...

class MultimodalInputs(_message.Message):
    __slots__ = ("image_urls", "video_urls", "audio_urls", "processed_features", "image_data", "video_data", "audio_data", "modalities")
    IMAGE_URLS_FIELD_NUMBER: _ClassVar[int]
    VIDEO_URLS_FIELD_NUMBER: _ClassVar[int]
    AUDIO_URLS_FIELD_NUMBER: _ClassVar[int]
    PROCESSED_FEATURES_FIELD_NUMBER: _ClassVar[int]
    IMAGE_DATA_FIELD_NUMBER: _ClassVar[int]
    VIDEO_DATA_FIELD_NUMBER: _ClassVar[int]
    AUDIO_DATA_FIELD_NUMBER: _ClassVar[int]
    MODALITIES_FIELD_NUMBER: _ClassVar[int]
    image_urls: _containers.RepeatedScalarFieldContainer[str]
    video_urls: _containers.RepeatedScalarFieldContainer[str]
    audio_urls: _containers.RepeatedScalarFieldContainer[str]
    processed_features: _struct_pb2.Struct
    image_data: _containers.RepeatedScalarFieldContainer[bytes]
    video_data: _containers.RepeatedScalarFieldContainer[bytes]
    audio_data: _containers.RepeatedScalarFieldContainer[bytes]
    modalities: _containers.RepeatedScalarFieldContainer[str]
    def __init__(self, image_urls: _Optional[_Iterable[str]] = ..., video_urls: _Optional[_Iterable[str]] = ..., audio_urls: _Optional[_Iterable[str]] = ..., processed_features: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., image_data: _Optional[_Iterable[bytes]] = ..., video_data: _Optional[_Iterable[bytes]] = ..., audio_data: _Optional[_Iterable[bytes]] = ..., modalities: _Optional[_Iterable[str]] = ...) -> None: ...

class GenerateResponse(_message.Message):
    __slots__ = ("request_id", "chunk", "complete", "error")
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    CHUNK_FIELD_NUMBER: _ClassVar[int]
    COMPLETE_FIELD_NUMBER: _ClassVar[int]
    ERROR_FIELD_NUMBER: _ClassVar[int]
    request_id: str
    chunk: GenerateStreamChunk
    complete: GenerateComplete
    error: GenerateError
    def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...

class GenerateStreamChunk(_message.Message):
    __slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time")
    TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
    TEXT_FIELD_NUMBER: _ClassVar[int]
    PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
    COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
    CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
    LOGPROBS_FIELD_NUMBER: _ClassVar[int]
    HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
    GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
    QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]
    token_id: int
    text: str
    prompt_tokens: int
    completion_tokens: int
    cached_tokens: int
    logprobs: LogProbs
    hidden_states: _containers.RepeatedScalarFieldContainer[float]
    generation_time: float
    queue_time: int
    def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ...

class GenerateComplete(_message.Message):
    __slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states")
    class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        __slots__ = ()
        STOP: _ClassVar[GenerateComplete.FinishReason]
        LENGTH: _ClassVar[GenerateComplete.FinishReason]
        EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
        STOP_STR: _ClassVar[GenerateComplete.FinishReason]
        ABORT: _ClassVar[GenerateComplete.FinishReason]
    STOP: GenerateComplete.FinishReason
    LENGTH: GenerateComplete.FinishReason
    EOS_TOKEN: GenerateComplete.FinishReason
    STOP_STR: GenerateComplete.FinishReason
    ABORT: GenerateComplete.FinishReason
    OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
    OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
    FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
    ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
    ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
    output_ids: _containers.RepeatedScalarFieldContainer[int]
    output_text: str
    finish_reason: GenerateComplete.FinishReason
    all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
    all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...

class GenerateError(_message.Message):
    __slots__ = ("message", "http_status_code", "details")
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    HTTP_STATUS_CODE_FIELD_NUMBER: _ClassVar[int]
    DETAILS_FIELD_NUMBER: _ClassVar[int]
    message: str
    http_status_code: str
    details: str
    def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...

class LogProbs(_message.Message):
    __slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")
    TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
    TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
    token_logprobs: _containers.RepeatedScalarFieldContainer[float]
    token_ids: _containers.RepeatedScalarFieldContainer[int]
    top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
    token_texts: _containers.RepeatedScalarFieldContainer[str]
    def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...

class TopLogProbs(_message.Message):
    __slots__ = ("values", "token_ids", "token_texts")
    VALUES_FIELD_NUMBER: _ClassVar[int]
    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
    values: _containers.RepeatedScalarFieldContainer[float]
    token_ids: _containers.RepeatedScalarFieldContainer[int]
    token_texts: _containers.RepeatedScalarFieldContainer[str]
    def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...

class HiddenStates(_message.Message):
    __slots__ = ("values", "layer", "position")
    VALUES_FIELD_NUMBER: _ClassVar[int]
    LAYER_FIELD_NUMBER: _ClassVar[int]
    POSITION_FIELD_NUMBER: _ClassVar[int]
    values: _containers.RepeatedScalarFieldContainer[float]
    layer: int
    position: int
    def __init__(self, values: _Optional[_Iterable[float]] = ..., layer: _Optional[int] = ..., position: _Optional[int] = ...) -> None: ...

class EmbedRequest(_message.Message):
    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "log_metrics", "token_type_ids", "data_parallel_rank", "is_cross_encoder", "texts")
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    TOKENIZED_FIELD_NUMBER: _ClassVar[int]
    MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
    SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
    LOG_METRICS_FIELD_NUMBER: _ClassVar[int]
    TOKEN_TYPE_IDS_FIELD_NUMBER: _ClassVar[int]
    DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
    IS_CROSS_ENCODER_FIELD_NUMBER: _ClassVar[int]
    TEXTS_FIELD_NUMBER: _ClassVar[int]
    request_id: str
    tokenized: TokenizedInput
    mm_inputs: MultimodalInputs
    sampling_params: SamplingParams
    log_metrics: bool
    token_type_ids: _containers.RepeatedScalarFieldContainer[int]
    data_parallel_rank: int
    is_cross_encoder: bool
    texts: _containers.RepeatedScalarFieldContainer[str]
    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., log_metrics: bool = ..., token_type_ids: _Optional[_Iterable[int]] = ..., data_parallel_rank: _Optional[int] = ..., is_cross_encoder: bool = ..., texts: _Optional[_Iterable[str]] = ...) -> None: ...

class EmbedResponse(_message.Message):
    __slots__ = ("request_id", "complete", "error")
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    COMPLETE_FIELD_NUMBER: _ClassVar[int]
    ERROR_FIELD_NUMBER: _ClassVar[int]
    request_id: str
    complete: EmbedComplete
    error: EmbedError
    def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...

class EmbedComplete(_message.Message):
    __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings")
    EMBEDDING_FIELD_NUMBER: _ClassVar[int]
    PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
    CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
    EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
    GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
    BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
    embedding: _containers.RepeatedScalarFieldContainer[float]
    prompt_tokens: int
    cached_tokens: int
    embedding_dim: int
    generation_time: float
    batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
    def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...

class Embedding(_message.Message):
    __slots__ = ("values", "index")
    VALUES_FIELD_NUMBER: _ClassVar[int]
    INDEX_FIELD_NUMBER: _ClassVar[int]
    values: _containers.RepeatedScalarFieldContainer[float]
    index: int
    def __init__(self, values: _Optional[_Iterable[float]] = ..., index: _Optional[int] = ...) -> None: ...

class EmbedError(_message.Message):
    __slots__ = ("message", "code", "details")
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    CODE_FIELD_NUMBER: _ClassVar[int]
    DETAILS_FIELD_NUMBER: _ClassVar[int]
    message: str
    code: str
    details: str
    def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...
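
For the embedding path, the request/response pair above can be used along these lines; the values are illustrative, and treating complete and error as mutually exclusive outcomes is an assumption about the .proto definition:

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

embed_req = pb2.EmbedRequest(
    request_id="emb-1",
    tokenized=pb2.TokenizedInput(original_text="hello", input_ids=[15339]),
)

def unpack(resp: pb2.EmbedResponse) -> list:
    # Surface errors first; otherwise return the embedding vector.
    if resp.HasField("error"):
        raise RuntimeError(f"{resp.error.code}: {resp.error.message}")
    return list(resp.complete.embedding)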

class HealthCheckRequest(_message.Message):
    __slots__ = ("tokenized",)
    TOKENIZED_FIELD_NUMBER: _ClassVar[int]
    tokenized: TokenizedInput
    def __init__(self, tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...) -> None: ...

class HealthCheckResponse(_message.Message):
    __slots__ = ("healthy", "message")
    HEALTHY_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    healthy: bool
    message: str
    def __init__(self, healthy: bool = ..., message: _Optional[str] = ...) -> None: ...

class AbortRequest(_message.Message):
    __slots__ = ("request_id", "reason")
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    REASON_FIELD_NUMBER: _ClassVar[int]
    request_id: str
    reason: str
    def __init__(self, request_id: _Optional[str] = ..., reason: _Optional[str] = ...) -> None: ...

class AbortResponse(_message.Message):
    __slots__ = ("success", "message")
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    success: bool
    message: str
    def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...

class LoadLoRARequest(_message.Message):
    __slots__ = ("adapter_id", "adapter_path", "rank")
    ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
    ADAPTER_PATH_FIELD_NUMBER: _ClassVar[int]
    RANK_FIELD_NUMBER: _ClassVar[int]
    adapter_id: str
    adapter_path: str
    rank: int
    def __init__(self, adapter_id: _Optional[str] = ..., adapter_path: _Optional[str] = ..., rank: _Optional[int] = ...) -> None: ...

class LoadLoRAResponse(_message.Message):
    __slots__ = ("success", "adapter_id", "message")
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    success: bool
    adapter_id: str
    message: str
    def __init__(self, success: bool = ..., adapter_id: _Optional[str] = ..., message: _Optional[str] = ...) -> None: ...

class UnloadLoRARequest(_message.Message):
    __slots__ = ("adapter_id",)
    ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
    adapter_id: str
    def __init__(self, adapter_id: _Optional[str] = ...) -> None: ...

class UnloadLoRAResponse(_message.Message):
    __slots__ = ("success", "message")
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    success: bool
    message: str
    def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...

class UpdateWeightsRequest(_message.Message):
    __slots__ = ("disk_path", "tensor_data", "remote_url", "weight_name")
    DISK_PATH_FIELD_NUMBER: _ClassVar[int]
    TENSOR_DATA_FIELD_NUMBER: _ClassVar[int]
    REMOTE_URL_FIELD_NUMBER: _ClassVar[int]
    WEIGHT_NAME_FIELD_NUMBER: _ClassVar[int]
    disk_path: str
    tensor_data: bytes
    remote_url: str
    weight_name: str
    def __init__(self, disk_path: _Optional[str] = ..., tensor_data: _Optional[bytes] = ..., remote_url: _Optional[str] = ..., weight_name: _Optional[str] = ...) -> None: ...

class UpdateWeightsResponse(_message.Message):
    __slots__ = ("success", "message")
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    success: bool
    message: str
    def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...

class GetInternalStateRequest(_message.Message):
    __slots__ = ("state_keys",)
    STATE_KEYS_FIELD_NUMBER: _ClassVar[int]
    state_keys: _containers.RepeatedScalarFieldContainer[str]
    def __init__(self, state_keys: _Optional[_Iterable[str]] = ...) -> None: ...

class GetInternalStateResponse(_message.Message):
    __slots__ = ("state",)
    STATE_FIELD_NUMBER: _ClassVar[int]
    state: _struct_pb2.Struct
    def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...

class SetInternalStateRequest(_message.Message):
    __slots__ = ("state",)
    STATE_FIELD_NUMBER: _ClassVar[int]
    state: _struct_pb2.Struct
    def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...

class SetInternalStateResponse(_message.Message):
    __slots__ = ("success", "message")
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    success: bool
    message: str
    def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ...
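
The internal-state messages above carry free-form JSON via google.protobuf.Struct, which accepts a plain Python dict through its update() method; the keys below are invented for illustration:

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

req = pb2.SetInternalStateRequest()
req.state.update({"attention_backend": "flashinfer", "load": 0.25})  # illustrative keys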
236
python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
Normal file
@@ -0,0 +1,236 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

from . import sglang_scheduler_pb2 as sglang__scheduler__pb2

GRPC_GENERATED_VERSION = '1.74.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in sglang_scheduler_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )


class SglangSchedulerStub(object):
    """Service definition for SGLang scheduler communication
    This protocol bridges the Rust router and Python scheduler
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.Generate = channel.unary_stream(
                '/sglang.grpc.scheduler.SglangScheduler/Generate',
                request_serializer=sglang__scheduler__pb2.GenerateRequest.SerializeToString,
                response_deserializer=sglang__scheduler__pb2.GenerateResponse.FromString,
                _registered_method=True)
        self.Embed = channel.unary_unary(
                '/sglang.grpc.scheduler.SglangScheduler/Embed',
                request_serializer=sglang__scheduler__pb2.EmbedRequest.SerializeToString,
                response_deserializer=sglang__scheduler__pb2.EmbedResponse.FromString,
                _registered_method=True)
        self.HealthCheck = channel.unary_unary(
                '/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
                request_serializer=sglang__scheduler__pb2.HealthCheckRequest.SerializeToString,
                response_deserializer=sglang__scheduler__pb2.HealthCheckResponse.FromString,
                _registered_method=True)
        self.Abort = channel.unary_unary(
                '/sglang.grpc.scheduler.SglangScheduler/Abort',
                request_serializer=sglang__scheduler__pb2.AbortRequest.SerializeToString,
                response_deserializer=sglang__scheduler__pb2.AbortResponse.FromString,
                _registered_method=True)
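
A minimal synchronous client built on this stub might look like the sketch below; the address, prompt, and sampling values are assumptions, and Generate is consumed as a unary-stream call:

import grpc

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2
from sglang.srt.grpc import sglang_scheduler_pb2_grpc as pb2_grpc

def stream_generate(address="localhost:30000"):
    with grpc.insecure_channel(address) as channel:
        stub = pb2_grpc.SglangSchedulerStub(channel)
        request = pb2.GenerateRequest(
            request_id="req-1",
            tokenized=pb2.TokenizedInput(original_text="Hi", input_ids=[13347]),
            sampling_params=pb2.SamplingParams(temperature=0.7, max_new_tokens=32),
        )
        # Generate is unary-stream: one request in, many GenerateResponse out.
        for response in stub.Generate(request):
            if response.HasField("chunk"):
                print(response.chunk.text, end="", flush=True)
            elif response.HasField("complete"):
                print(f"\nfinish_reason={response.complete.finish_reason}")
            elif response.HasField("error"):
                raise RuntimeError(response.error.message)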


class SglangSchedulerServicer(object):
    """Service definition for SGLang scheduler communication
    This protocol bridges the Rust router and Python scheduler
    """

    def Generate(self, request, context):
        """Submit a generation request (supports streaming)
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Embed(self, request, context):
        """Submit an embedding request
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def HealthCheck(self, request, context):
        """Health check and metrics
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def Abort(self, request, context):
        """Abort a running request
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_SglangSchedulerServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Generate': grpc.unary_stream_rpc_method_handler(
                    servicer.Generate,
                    request_deserializer=sglang__scheduler__pb2.GenerateRequest.FromString,
                    response_serializer=sglang__scheduler__pb2.GenerateResponse.SerializeToString,
            ),
            'Embed': grpc.unary_unary_rpc_method_handler(
                    servicer.Embed,
                    request_deserializer=sglang__scheduler__pb2.EmbedRequest.FromString,
                    response_serializer=sglang__scheduler__pb2.EmbedResponse.SerializeToString,
            ),
            'HealthCheck': grpc.unary_unary_rpc_method_handler(
                    servicer.HealthCheck,
                    request_deserializer=sglang__scheduler__pb2.HealthCheckRequest.FromString,
                    response_serializer=sglang__scheduler__pb2.HealthCheckResponse.SerializeToString,
            ),
            'Abort': grpc.unary_unary_rpc_method_handler(
                    servicer.Abort,
                    request_deserializer=sglang__scheduler__pb2.AbortRequest.FromString,
                    response_serializer=sglang__scheduler__pb2.AbortResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers)
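
To stand up a server against these handlers, one subclasses the generated servicer and registers it with add_SglangSchedulerServicer_to_server; a sketch under assumed values (thread-pool size and port are arbitrary, and only HealthCheck is overridden):

from concurrent import futures

import grpc

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2
from sglang.srt.grpc import sglang_scheduler_pb2_grpc as pb2_grpc

class HealthOnlyServicer(pb2_grpc.SglangSchedulerServicer):
    def HealthCheck(self, request, context):
        # A real implementation would probe the scheduler loop first.
        return pb2.HealthCheckResponse(healthy=True, message="ok")

def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    pb2_grpc.add_SglangSchedulerServicer_to_server(HealthOnlyServicer(), server)
    server.add_insecure_port("[::]:30000")  # assumed port
    server.start()
    server.wait_for_termination()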


# This class is part of an EXPERIMENTAL API.
class SglangScheduler(object):
    """Service definition for SGLang scheduler communication
    This protocol bridges the Rust router and Python scheduler
    """

    @staticmethod
    def Generate(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_stream(
            request,
            target,
            '/sglang.grpc.scheduler.SglangScheduler/Generate',
            sglang__scheduler__pb2.GenerateRequest.SerializeToString,
            sglang__scheduler__pb2.GenerateResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Embed(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/sglang.grpc.scheduler.SglangScheduler/Embed',
            sglang__scheduler__pb2.EmbedRequest.SerializeToString,
            sglang__scheduler__pb2.EmbedResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def HealthCheck(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
            sglang__scheduler__pb2.HealthCheckRequest.SerializeToString,
            sglang__scheduler__pb2.HealthCheckResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def Abort(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/sglang.grpc.scheduler.SglangScheduler/Abort',
            sglang__scheduler__pb2.AbortRequest.SerializeToString,
            sglang__scheduler__pb2.AbortResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
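
The experimental wrappers above allow one-off calls without managing a channel explicitly; for example (the target is an assumption, and insecure=True stands in for real channel credentials):

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2
from sglang.srt.grpc import sglang_scheduler_pb2_grpc as pb2_grpc

resp = pb2_grpc.SglangScheduler.HealthCheck(
    pb2.HealthCheckRequest(),
    target="localhost:30000",  # assumed address
    insecure=True,
)
print(resp.healthy, resp.message)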
@@ -2238,6 +2238,7 @@ class ServerArgs:
        args.pp_size = args.pipeline_parallel_size
        args.dp_size = args.data_parallel_size
        args.ep_size = args.expert_parallel_size

        attrs = [attr.name for attr in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})