diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a295f2eb4..10b05c204 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,17 +22,19 @@ repos: rev: 5.13.2 hooks: - id: isort + exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.7 hooks: - id: ruff args: [--select=F401, --fixable=F401] files: ^(benchmark/|docs/|examples/) - exclude: \.ipynb$ + exclude: \.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$ - repo: https://github.com/psf/black rev: 24.10.0 hooks: - id: black-jupyter + exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$' - repo: https://github.com/codespell-project/codespell rev: v2.4.1 hooks: @@ -42,7 +44,11 @@ repos: exclude: | (?x)^( test/srt/test_reasoning_parser\.py| - docs/advanced_features/vlm_query\.ipynb + docs/advanced_features/vlm_query\.ipynb| + python/sglang/srt/grpc/.*_pb2\.py| + python/sglang/srt/grpc/.*_pb2_grpc\.py| + python/sglang/srt/grpc/.*_pb2\.pyi| + python/sglang/srt/grpc/.*_pb2_grpc\.pyi )$ - repo: https://github.com/pre-commit/mirrors-clang-format rev: v18.1.8 diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py new file mode 100644 index 000000000..91c1d9e31 --- /dev/null +++ b/python/sglang/srt/entrypoints/grpc_request_manager.py @@ -0,0 +1,580 @@ +""" +gRPC Request Manager - Orchestrates request lifecycle without tokenization. +Mimics TokenizerManager's state management and ZMQ communication patterns. +""" + +import asyncio +import dataclasses +import logging +import os +import signal +import sys +import threading +import time +from typing import Any, Dict, List, Optional, Union + +import grpc +import zmq +import zmq.asyncio + +from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOut, + BatchTokenIDOut, + HealthCheckOutput, + TokenizedEmbeddingReqInput, + TokenizedGenerateReqInput, +) +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import get_zmq_socket, kill_process_tree +from sglang.utils import get_exception_traceback + +logger = logging.getLogger(__name__) + + +class GrpcSignalHandler: + """Minimal signal handler for gRPC server - delegates real crash handling to scheduler.""" + + def __init__(self, grpc_manager): + self.grpc_manager = grpc_manager + + def sigterm_handler(self, signum=None, frame=None): + """Handle SIGTERM by gracefully shutting down gRPC server.""" + logger.warning( + f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..." + ) + self.grpc_manager.gracefully_exit = True + + def running_phase_sigquit_handler(self, signum=None, frame=None): + """Handle SIGQUIT from failed scheduler process.""" + logger.error( + "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server." + ) + logger.info( + "Note: Crash dumps are handled by the scheduler process, not the gRPC server." 
+ ) + # Just exit cleanly - the scheduler handles crash dumps + kill_process_tree(os.getpid(), include_parent=True) + + +@dataclasses.dataclass +class GrpcReqState: + """State tracking for a gRPC request.""" + + # Request identification + request_id: str + grpc_context: Optional[grpc.aio.ServicerContext] + + # Communication + out_queue: asyncio.Queue + finished: bool + event: asyncio.Event + obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput] + + # Metrics (same as TokenizerManager's ReqState) + created_time: float + finished_time: float = 0.0 + first_token_time: float = 0.0 + last_time: float = 0.0 + last_completion_tokens: int = 1 + + # Streaming state + last_output_offset: int = 0 + stream_finished: bool = False + + # Output accumulation + text: str = "" + output_ids: List[int] = dataclasses.field(default_factory=list) + input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) + input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) + output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) + output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) + input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) + input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) + output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) + output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) + + # Session state + session_id: Optional[str] = None + is_session_request: bool = False + + +class GrpcRequestManager: + """ + Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration + behaviors without tokenization. + """ + + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + ): + """Initialize the gRPC request manager.""" + self.server_args = server_args + self.port_args = port_args + + # ZMQ Communication Setup (same pattern as TokenizerManager) + context = zmq.asyncio.Context(2) + + # Socket for receiving outputs from scheduler + self.recv_from_scheduler = get_zmq_socket( + context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True + ) + + # Socket for sending requests to scheduler + self.send_to_scheduler = get_zmq_socket( + context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True + ) + + # State Management (from TokenizerManager) + self.rid_to_state: Dict[str, GrpcReqState] = {} + self.asyncio_tasks: set = set() + self.gracefully_exit = False + self.no_create_loop = False + self.event_loop = None + + # Pause/Resume Control + self.is_pause = False + self.is_pause_cond = asyncio.Condition() + + # Metrics + self.request_counter = 0 + self.request_counter_lock = asyncio.Lock() + self.last_receive_tstamp = time.time() + + # Crash dump for debugging + self.crash_dump_request_list = [] + self.crash_dump_performed = False + + logger.info( + f"GrpcRequestManager initialized with ZMQ IPC: " + f"recv={port_args.detokenizer_ipc_name}, " + f"send={port_args.scheduler_input_ipc_name}" + ) + + async def generate_request( + self, + obj: TokenizedGenerateReqInput, + request_id: Optional[str] = None, + grpc_context: Optional[grpc.aio.ServicerContext] = None, + ) -> asyncio.Queue: + """ + Submit a generation request to the scheduler. + Returns a queue for streaming outputs. 
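+
+        Each queue item is a plain dict: streaming chunks carry "request_id",
+        "text", "token_ids", "finished", and "meta_info" (plus "logprobs" when
+        requested), while failures carry an "error" key. A minimal consumption
+        sketch (``manager`` here is an assumed GrpcRequestManager instance
+        whose handle loop is running):
+
+            queue = await manager.generate_request(tokenized_req)
+            while True:
+                out = await queue.get()
+                if "error" in out or out.get("finished"):
+                    break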
+ """ + # Generate request ID if not provided + if request_id is None: + async with self.request_counter_lock: + request_id = f"grpc-{self.request_counter}" + self.request_counter += 1 + + obj.rid = request_id + + # TODO: support log_request + + # Create request state + state = GrpcReqState( + request_id=request_id, + grpc_context=grpc_context, + out_queue=asyncio.Queue(), + finished=False, + event=asyncio.Event(), + obj=obj, + created_time=time.time(), + ) + + # Track session if needed + if hasattr(obj, "session_params") and obj.session_params: + state.session_id = obj.session_params.session_id + state.is_session_request = True + + # Register state + self.rid_to_state[request_id] = state + self.record_request_for_crash_dump(obj) + + # Send to scheduler via ZMQ + try: + await self._send_to_scheduler(obj) + except Exception as e: + # Clean up on failure + del self.rid_to_state[request_id] + raise RuntimeError(f"Failed to send request to scheduler: {e}") + + return state.out_queue + + async def embedding_request( + self, + obj: TokenizedEmbeddingReqInput, + request_id: Optional[str] = None, + ) -> asyncio.Future: + """ + Submit an embedding request to the scheduler. + Returns a future that will contain the embedding result. + """ + # Generate request ID if not provided + if request_id is None: + async with self.request_counter_lock: + request_id = f"grpc-embed-{self.request_counter}" + self.request_counter += 1 + + obj.rid = request_id + + # Create request state + state = GrpcReqState( + request_id=request_id, + grpc_context=None, + out_queue=asyncio.Queue(), + finished=False, + event=asyncio.Event(), + obj=obj, + created_time=time.time(), + ) + + # Register state + self.rid_to_state[request_id] = state + + # Create future for result + future = asyncio.Future() + + # Send to scheduler + try: + await self._send_to_scheduler(obj) + except Exception as e: + del self.rid_to_state[request_id] + future.set_exception(e) + return future + + # Wait for result in background + async def wait_for_result(): + try: + # Wait for completion + await state.event.wait() + # Get result from queue + result = await state.out_queue.get() + future.set_result(result) + except Exception as e: + future.set_exception(e) + finally: + # Clean up + if request_id in self.rid_to_state: + del self.rid_to_state[request_id] + + asyncio.create_task(wait_for_result()) + return future + + async def abort_request(self, request_id: str) -> bool: + """Abort a running request.""" + if request_id not in self.rid_to_state: + return False + + # Send abort to scheduler + abort_req = AbortReq(rid=request_id) + try: + await self._send_to_scheduler(abort_req) + except Exception as e: + logger.error(f"Failed to send abort request: {e}") + return False + + # Mark as finished + state = self.rid_to_state.get(request_id) + if state: + state.finished = True + state.stream_finished = True + state.event.set() + + # Send abort notification to output queue + await state.out_queue.put({"error": "Request aborted", "abort": True}) + + return True + + async def pause_generation(self): + """Pause generation processing.""" + async with self.is_pause_cond: + self.is_pause = True + logger.info("Generation paused") + + async def resume_generation(self): + """Resume generation processing.""" + async with self.is_pause_cond: + self.is_pause = False + self.is_pause_cond.notify_all() + logger.info("Generation resumed") + + async def handle_loop(self): + """ + Main event loop - processes outputs from scheduler. + Mimics TokenizerManager's handle_loop. 
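+
+        Dispatches BatchTokenIDOut, BatchEmbeddingOut, and HealthCheckOutput
+        messages from the scheduler to the matching per-request output queues;
+        any other payload type is logged and dropped.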
+        """
+        while not self.gracefully_exit:
+            try:
+                # Receive from scheduler
+                recv_obj = await self.recv_from_scheduler.recv_pyobj()
+                self.last_receive_tstamp = time.time()
+
+                # Check for pause
+                async with self.is_pause_cond:
+                    while self.is_pause:
+                        await self.is_pause_cond.wait()
+
+                # Handle different output types
+                if isinstance(recv_obj, BatchTokenIDOut):
+                    await self._handle_batch_output(recv_obj)
+                elif isinstance(recv_obj, BatchEmbeddingOut):
+                    await self._handle_embedding_output(recv_obj)
+                elif isinstance(recv_obj, HealthCheckOutput):
+                    await self._handle_health_check_output(recv_obj)
+                else:
+                    logger.warning(f"Unknown output type: {type(recv_obj)}")
+
+            except zmq.error.Again:
+                # Timeout, check if we should exit
+                if self.gracefully_exit:
+                    break
+                continue
+            except Exception as e:
+                logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
+                if self.gracefully_exit:
+                    break
+
+    async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
+        """Handle batch generation output from scheduler."""
+        # Process each request in the batch
+        for i, rid in enumerate(batch_out.rids):
+            if rid not in self.rid_to_state:
+                continue
+
+            state = self.rid_to_state[rid]
+
+            # Update metrics
+            now = time.time()
+            if state.first_token_time == 0.0:
+                state.first_token_time = now
+            state.last_time = now
+
+            # Extract output for this request
+            output_data = {
+                "request_id": rid,
+                "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
+                "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
+                "finished": batch_out.finished_reasons[i] is not None,
+                "meta_info": {
+                    "prompt_tokens": (
+                        batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
+                    ),
+                    "completion_tokens": (
+                        batch_out.completion_tokens[i]
+                        if batch_out.completion_tokens
+                        else 0
+                    ),
+                    "finish_reason": (
+                        str(batch_out.finished_reasons[i])
+                        if batch_out.finished_reasons[i]
+                        else None
+                    ),
+                },
+            }
+
+            # Add logprobs if available
+            if batch_out.output_token_logprobs_val and i < len(
+                batch_out.output_token_logprobs_val
+            ):
+                output_data["logprobs"] = {
+                    "tokens": batch_out.output_token_logprobs_val[i],
+                    "top_logprobs": (
+                        batch_out.output_top_logprobs_val[i]
+                        if batch_out.output_top_logprobs_val
+                        and i < len(batch_out.output_top_logprobs_val)
+                        else None
+                    ),
+                }
+
+            # Update state
+            if output_data["text"]:
+                state.text += output_data["text"][state.last_output_offset :]
+                state.last_output_offset = len(output_data["text"])
+
+            if output_data["token_ids"]:
+                state.output_ids.extend(output_data["token_ids"])
+
+            # Send to output queue
+            await state.out_queue.put(output_data)
+
+            # Handle completion
+            if output_data["finished"]:
+                state.finished = True
+                state.finished_time = now
+                state.stream_finished = True
+                state.event.set()
+
+                # Remove from tracking after a delay. Bind rid as a default
+                # argument: Python closures capture variables late, so without
+                # this every delayed cleanup task would see only the loop's
+                # final rid and leak the other entries.
+                async def cleanup(rid=rid):
+                    await asyncio.sleep(5.0)
+                    if rid in self.rid_to_state:
+                        del self.rid_to_state[rid]
+
+                asyncio.create_task(cleanup())
+
+    async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
+        """Handle batch embedding output from scheduler."""
+        for i, rid in enumerate(batch_out.rids):
+            if rid not in self.rid_to_state:
+                continue
+
+            state = self.rid_to_state[rid]
+
+            # Create result
+            result = {
+                "request_id": rid,
+                "embedding": batch_out.embeddings[i],
+                "prompt_tokens": (
+                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
+                ),
+                "finish_reason": (
+                    batch_out.finish_reason[i] if batch_out.finish_reason else None
+                ),
+            }
+
+            # Send result
+            await 
state.out_queue.put(result) + + # Mark as finished + state.finished = True + state.finished_time = time.time() + state.event.set() + + async def _handle_health_check_output(self, health_out: HealthCheckOutput): + """Handle health check output from scheduler.""" + rid = health_out.rid + + if rid not in self.rid_to_state: + logger.warning(f"Health check output for unknown request: {rid}") + return + + state = self.rid_to_state[rid] + + # Create health check result + result = { + "request_id": rid, + "healthy": True, # If we got a response, scheduler is healthy + "output_text": ( + health_out.output_str if hasattr(health_out, "output_str") else "" + ), + "finish_reason": ( + health_out.finish_reason + if hasattr(health_out, "finish_reason") + else "stop" + ), + } + + # Send result + await state.out_queue.put(result) + + # Mark as finished + state.finished = True + state.finished_time = time.time() + state.event.set() + + async def _send_to_scheduler(self, obj): + """Send an object to the scheduler via ZMQ.""" + try: + self.send_to_scheduler.send_pyobj(obj) + except Exception as e: + logger.error(f"Failed to send to scheduler: {e}") + raise + + def record_request_for_crash_dump(self, obj): + """Record request for potential crash dump.""" + if len(self.crash_dump_request_list) < 100: + self.crash_dump_request_list.append( + { + "time": time.time(), + "request_id": getattr(obj, "rid", "unknown"), + "type": type(obj).__name__, + } + ) + + async def shutdown(self): + """Gracefully shutdown the request manager.""" + logger.info("Shutting down GrpcRequestManager") + self.gracefully_exit = True + + # Cancel all pending requests + for rid, state in self.rid_to_state.items(): + if not state.finished: + await state.out_queue.put( + {"error": "Server shutting down", "shutdown": True} + ) + state.finished = True + state.event.set() + + # Wait for tasks to complete + if self.asyncio_tasks: + await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True) + + # Close ZMQ sockets + self.recv_from_scheduler.close() + self.send_to_scheduler.close() + + logger.info("GrpcRequestManager shutdown complete") + + def get_server_info(self) -> Dict[str, Any]: + """Get server information for health checks.""" + return { + "active_requests": len(self.rid_to_state), + "paused": self.is_pause, + "last_receive_time": self.last_receive_tstamp, + } + + def auto_create_handle_loop(self): + """Automatically create and start the handle_loop task, matching TokenizerManager pattern.""" + if self.no_create_loop: + return + + self.no_create_loop = True + loop = asyncio.get_event_loop() + self.asyncio_tasks.add( + loop.create_task(print_exception_wrapper(self.handle_loop)) + ) + + self.event_loop = loop + + # We cannot add signal handler when the grpc manager is not in + # the main thread due to the CPython limitation. + if threading.current_thread() is threading.main_thread(): + signal_handler = GrpcSignalHandler(self) + loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler) + # Update the signal handler for the process. It overrides the sigquit handler in the launch phase. + loop.add_signal_handler( + signal.SIGQUIT, signal_handler.running_phase_sigquit_handler + ) + else: + logger.warning( + "Signal handler is not added because the grpc request manager is " + "not in the main thread. This disables graceful shutdown of the " + "grpc request manager when SIGTERM is received." 
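+                # Launch the manager from the main thread when graceful
+                # SIGTERM shutdown is required.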
+        )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )
+
+    async def sigterm_watchdog(self):
+        """Watchdog to handle SIGTERM gracefully, matching TokenizerManager pattern."""
+        while not self.gracefully_exit:
+            await asyncio.sleep(1.0)
+
+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio task does not surface its exception.
+    Wrap the coroutine so the exception is logged before the process exits.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"GrpcRequestManager hit an exception: {traceback}")
+        if hasattr(func, "__self__") and isinstance(func.__self__, GrpcRequestManager):
+            # Guard the call: unlike TokenizerManager, GrpcRequestManager does
+            # not implement dump_requests_before_crash, so calling it blindly
+            # would raise a secondary AttributeError on the crash path.
+            dump = getattr(func.__self__, "dump_requests_before_crash", None)
+            if dump is not None:
+                dump()
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py
new file mode 100644
index 000000000..f7edf7743
--- /dev/null
+++ b/python/sglang/srt/entrypoints/grpc_server.py
@@ -0,0 +1,680 @@
+"""
+Standalone gRPC Server for SGLang - Fully separated from HTTP server.
+Uses GrpcRequestManager for orchestration without tokenization.
+"""
+
+import argparse
+import asyncio
+import logging
+import multiprocessing as mp
+import os
+import signal
+import time
+from concurrent import futures
+from typing import AsyncIterator, Dict, Optional, Tuple
+
+import grpc
+from grpc_reflection.v1alpha import reflection
+
+from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
+from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
+from sglang.srt.managers.data_parallel_controller import (
+    run_data_parallel_controller_process,
+)
+from sglang.srt.managers.io_struct import (
+    TokenizedEmbeddingReqInput,
+    TokenizedGenerateReqInput,
+)
+from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
+
+
+def _launch_scheduler_process_only(
+    server_args: ServerArgs,
+    port_args: Optional[PortArgs] = None,
+) -> Tuple[Dict, PortArgs, list]:
+    """
+    Launch only the scheduler process(es) without tokenizer/detokenizer.
+    Returns scheduler info, port args, and list of scheduler processes.
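+
+    With dp_size == 1, one scheduler process is spawned per (pp_rank, tp_rank)
+    pair owned by this node; otherwise a single data parallel controller
+    process is launched to manage the per-rank schedulers.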
+ """ + # Configure global environment + configure_logger(server_args) + server_args.check_server_args() + + # Allocate ports for inter-process communications + if port_args is None: + port_args = PortArgs.init_new(server_args) + logger.info(f"{server_args=}") + + # Prepare model and tokenizer paths + server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( + server_args.model_path, server_args.tokenizer_path + ) + + scheduler_procs = [] + if server_args.dp_size == 1: + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + scheduler_pipe_readers = [] + + nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1) + tp_size_per_node = server_args.tp_size // nnodes_per_tp_group + tp_rank_range = range( + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group), + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1), + ) + + pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1) + pp_rank_range = range( + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group), + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1), + ) + + for pp_rank in pp_rank_range: + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = ( + server_args.base_gpu_id + + ((pp_rank % pp_size_per_node) * tp_size_per_node) + + (tp_rank % tp_size_per_node) * server_args.gpu_id_step + ) + moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + proc = mp.Process( + target=run_scheduler_process, + args=( + server_args, + port_args, + gpu_id, + tp_rank, + moe_ep_rank, + pp_rank, + None, + writer, + None, + ), + ) + + with memory_saver_adapter.configure_subprocess(): + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + else: + # Launch the data parallel controller + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_readers = [reader] + proc = mp.Process( + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), + ) + proc.start() + scheduler_procs.append(proc) + + # TODO(CatherineSue): handle cases for multi-node + + # Wait for all scheduler processes to be ready + scheduler_infos = [] + for i, reader in enumerate(scheduler_pipe_readers): + try: + data = reader.recv() + except EOFError: + logger.error( + f"Rank {i} scheduler is dead. Please check if there are relevant logs." + ) + scheduler_procs[i].join() + logger.error(f"Exit code: {scheduler_procs[i].exitcode}") + raise RuntimeError(f"Failed to initialize scheduler rank {i}") + + if data.get("status") != "ready": + raise RuntimeError( + f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}" + ) + scheduler_infos.append(data) + + logger.info( + f"All {len(scheduler_procs)} scheduler process(es) initialized successfully" + ) + + # Return the first scheduler's info (they should all be the same) + return scheduler_infos[0], port_args, scheduler_procs + + +class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer): + """ + Standalone gRPC service implementation using GrpcRequestManager. + Fully separated from HTTP server with its own process and no shared globals. 
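+
+    Each RPC converts its protobuf message into the scheduler's tokenized
+    request dataclasses (TokenizedGenerateReqInput / TokenizedEmbeddingReqInput)
+    and consumes results from the per-request asyncio queue owned by
+    GrpcRequestManager.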
+ """ + + def __init__( + self, + request_manager: GrpcRequestManager, + server_args: ServerArgs, + model_info: Dict, + ): + """Initialize the standalone gRPC service.""" + self.request_manager = request_manager + self.server_args = server_args + self.model_info = model_info + self.start_time = time.time() + + # Start the request manager's event loop using auto_create_handle_loop + self.request_manager.auto_create_handle_loop() + + logger.info("Standalone gRPC scheduler service initialized") + + async def Generate( + self, + request: sglang_scheduler_pb2.GenerateRequest, + context: grpc.aio.ServicerContext, + ) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]: + """Handle generation requests with streaming responses.""" + logger.info(f"Generation request: {request.request_id}") + + try: + # Convert gRPC request to internal format + tokenized_req = self._convert_generate_request(request) + + # Submit to request manager + output_queue = await self.request_manager.generate_request( + obj=tokenized_req, + request_id=request.request_id, + grpc_context=context, + ) + + # Stream outputs + while True: + try: + # Get output with timeout + output = await asyncio.wait_for(output_queue.get(), timeout=4) + + # Check for errors + if "error" in output: + yield sglang_scheduler_pb2.GenerateResponse( + request_id=request.request_id, + error=sglang_scheduler_pb2.GenerateError( + message=output["error"], + http_status_code=( + "500" if "abort" not in output else "499" + ), + ), + ) + break + + # Check if finished + if output.get("finished", False): + # Send completion + yield self._create_completion_response( + request.request_id, output + ) + break + else: + # Send chunk + yield self._create_chunk_response(request.request_id, output) + + except asyncio.TimeoutError: + # Check if context is still active + if context.cancelled(): + # Abort the request + await self.request_manager.abort_request(request.request_id) + break + continue + + except Exception as e: + logger.error(f"Generate failed: {e}\n{get_exception_traceback()}") + yield sglang_scheduler_pb2.GenerateResponse( + request_id=request.request_id, + error=sglang_scheduler_pb2.GenerateError( + message=str(e), + http_status_code="500", + details=get_exception_traceback(), + ), + ) + + async def Embed( + self, + request: sglang_scheduler_pb2.EmbedRequest, + context: grpc.aio.ServicerContext, + ) -> sglang_scheduler_pb2.EmbedResponse: + """Handle embedding requests.""" + logger.info(f"Embedding request: {request.request_id}") + + try: + # Convert request + tokenized_req = self._convert_embed_request(request) + + # Submit to request manager + future = await self.request_manager.embedding_request( + obj=tokenized_req, + request_id=request.request_id, + ) + + # Wait for result + result = await future + + # Create response + return sglang_scheduler_pb2.EmbedResponse( + request_id=request.request_id, + complete=sglang_scheduler_pb2.EmbedComplete( + embedding=result["embedding"], + prompt_tokens=result.get("prompt_tokens", 0), + cached_tokens=0, + embedding_dim=len(result["embedding"]), + generation_time=time.time() - self.start_time, + ), + ) + + except Exception as e: + logger.error(f"Embed failed: {e}\n{get_exception_traceback()}") + return sglang_scheduler_pb2.EmbedResponse( + request_id=request.request_id, + error=sglang_scheduler_pb2.EmbedError( + message=str(e), + code="INTERNAL_ERROR", + details=get_exception_traceback(), + ), + ) + + async def HealthCheck( + self, + request: sglang_scheduler_pb2.HealthCheckRequest, + context: 
grpc.aio.ServicerContext, + ) -> sglang_scheduler_pb2.HealthCheckResponse: + """Health check by generating from client input.""" + try: + # Check if request manager is shutting down + if self.request_manager.gracefully_exit: + return sglang_scheduler_pb2.HealthCheckResponse( + healthy=False, message="Server shutting down" + ) + + # Extract tokenized input from request + if not request.HasField("tokenized"): + return sglang_scheduler_pb2.HealthCheckResponse( + healthy=False, message="Tokenized input required for health check" + ) + + input_text = request.tokenized.original_text + input_ids = list(request.tokenized.input_ids) + + # Create health check request + rid = f"HEALTH_CHECK_GRPC_{time.time()}" + + health_request = TokenizedGenerateReqInput( + rid=rid, + input_text=input_text, + input_ids=input_ids, + sampling_params=SGLSamplingParams(max_new_tokens=1, temperature=0.0), + stream=False, + mm_inputs=None, + return_logprob=False, + logprob_start_len=-1, + top_logprobs_num=0, + token_ids_logprob=None, + ) + + logger.info(f"Sending health check request to request manager...") + + # Submit and wait for response + output_queue = await self.request_manager.generate_request( + health_request, request_id=rid + ) + + try: + # Wait for response with configurable timeout + response = await asyncio.wait_for( + output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT + ) + + # Clean up + if rid in self.request_manager.rid_to_state: + del self.request_manager.rid_to_state[rid] + + return sglang_scheduler_pb2.HealthCheckResponse( + healthy=True, message="Health check passed" + ) + + except asyncio.TimeoutError: + # Clean up on timeout + if rid in self.request_manager.rid_to_state: + del self.request_manager.rid_to_state[rid] + + return sglang_scheduler_pb2.HealthCheckResponse( + healthy=False, message="Health check timeout" + ) + + except Exception as e: + logger.error(f"Health check failed: {e}") + return sglang_scheduler_pb2.HealthCheckResponse( + healthy=False, message=f"Health check error: {str(e)}" + ) + + async def Abort( + self, + request: sglang_scheduler_pb2.AbortRequest, + context: grpc.aio.ServicerContext, + ) -> sglang_scheduler_pb2.AbortResponse: + """Abort an ongoing request.""" + logger.info(f"Aborting request: {request.request_id}") + + try: + success = await self.request_manager.abort_request(request.request_id) + + return sglang_scheduler_pb2.AbortResponse( + success=success, + message=f"Request {request.request_id} {'aborted' if success else 'not found'}", + ) + except Exception as e: + logger.error(f"Abort failed: {e}") + return sglang_scheduler_pb2.AbortResponse( + success=False, + message=str(e), + ) + + # Helper methods for request/response conversion + + def _convert_generate_request( + self, grpc_req: sglang_scheduler_pb2.GenerateRequest + ) -> TokenizedGenerateReqInput: + """Convert gRPC GenerateRequest to internal format.""" + + # Extract tokenized input + if not grpc_req.HasField("tokenized"): + raise ValueError("Tokenized input must be provided") + + input_text = grpc_req.tokenized.original_text + input_ids = list(grpc_req.tokenized.input_ids) + + # Convert sampling params + sampling_params = self._convert_sampling_params(grpc_req.sampling_params) + + # Create request + return TokenizedGenerateReqInput( + rid=grpc_req.request_id, + input_text=input_text, + input_ids=input_ids, + mm_inputs=None, # TODO: implement mm support + sampling_params=sampling_params, + return_logprob=grpc_req.return_logprob, + logprob_start_len=grpc_req.logprob_start_len or -1, + 
top_logprobs_num=grpc_req.top_logprobs_num or 0,
+            stream=True,  # Always stream for gRPC
+            lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
+            token_ids_logprob=(
+                list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
+            ),
+        )
+
+    def _convert_embed_request(
+        self, grpc_req: sglang_scheduler_pb2.EmbedRequest
+    ) -> TokenizedEmbeddingReqInput:
+        """Convert gRPC EmbedRequest to internal format."""
+
+        # Extract tokenized input
+        if not grpc_req.HasField("tokenized"):
+            raise ValueError("Tokenized input must be provided")
+
+        input_text = grpc_req.tokenized.original_text
+        input_ids = list(grpc_req.tokenized.input_ids)
+
+        return TokenizedEmbeddingReqInput(
+            rid=grpc_req.request_id,
+            input_text=input_text,
+            input_ids=input_ids,
+        )
+
+    def _convert_sampling_params(
+        self, grpc_params: sglang_scheduler_pb2.SamplingParams
+    ) -> SGLSamplingParams:
+        """Convert gRPC SamplingParams to internal format."""
+
+        # Handle constraint types
+        regex = None
+        json_schema = None
+        ebnf_grammar = None
+
+        if grpc_params.HasField("regex"):
+            regex = grpc_params.regex
+        elif grpc_params.HasField("json_schema"):
+            json_schema = grpc_params.json_schema
+        elif grpc_params.HasField("ebnf_grammar"):
+            ebnf_grammar = grpc_params.ebnf_grammar
+
+        # NOTE: proto3 scalars have no field presence, so the `or` defaults
+        # treat an explicit 0 the same as unset (e.g. temperature=0.0 is
+        # coerced to 1.0).
+        return SGLSamplingParams(
+            temperature=grpc_params.temperature or 1.0,
+            top_p=grpc_params.top_p or 1.0,
+            top_k=grpc_params.top_k or -1,
+            min_p=grpc_params.min_p or 0.0,
+            frequency_penalty=grpc_params.frequency_penalty or 0.0,
+            presence_penalty=grpc_params.presence_penalty or 0.0,
+            repetition_penalty=grpc_params.repetition_penalty or 1.0,
+            max_new_tokens=grpc_params.max_new_tokens or 128,
+            min_new_tokens=grpc_params.min_new_tokens or 0,
+            stop=list(grpc_params.stop) if grpc_params.stop else None,
+            stop_token_ids=(
+                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
+            ),
+            skip_special_tokens=grpc_params.skip_special_tokens,
+            spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
+            regex=regex,
+            json_schema=json_schema,
+            ebnf=ebnf_grammar,
+            n=grpc_params.n or 1,
+            ignore_eos=grpc_params.ignore_eos,
+        )
+
+    def _create_chunk_response(
+        self, request_id: str, output: Dict
+    ) -> sglang_scheduler_pb2.GenerateResponse:
+        """Create a streaming chunk response."""
+        return sglang_scheduler_pb2.GenerateResponse(
+            request_id=request_id,
+            chunk=sglang_scheduler_pb2.GenerateStreamChunk(
+                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
+                text=output.get("text", ""),
+                prompt_tokens=0,
+                completion_tokens=len(output.get("token_ids", [])),
+                cached_tokens=0,
+                generation_time=time.time() - self.start_time,
+                queue_time=0,  # int32 in the proto; a float here raises TypeError
+            ),
+        )
+
+    def _create_completion_response(
+        self, request_id: str, output: Dict
+    ) -> sglang_scheduler_pb2.GenerateResponse:
+        """Create a completion response."""
+
+        # Determine finish reason
+        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
+        meta_info = output.get("meta_info", {})
+        if meta_info.get("finish_reason") == "length":
+            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
+        elif meta_info.get("finish_reason") == "eos_token":
+            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
+
+        return sglang_scheduler_pb2.GenerateResponse(
+            request_id=request_id,
+            complete=sglang_scheduler_pb2.GenerateComplete(
+                output_ids=output.get("token_ids", []),
+                output_text=output.get("text", ""),
+                finish_reason=finish_reason,
+            ),
+        )
+
+    async def shutdown(self):
+        """Shutdown the service."""
+        logger.info("Shutting down gRPC 
service") + + # Shutdown request manager (handles its own tasks) + await self.request_manager.shutdown() + + +async def serve_grpc( + server_args: ServerArgs, + model_info: Optional[Dict] = None, +): + """Start the standalone gRPC server with integrated scheduler.""" + + # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC) + logger.info("Launching scheduler process(es)...") + scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only( + server_args=server_args, + ) + + # Update model info from scheduler info + if model_info is None: + model_info = { + "model_name": server_args.model_path, + "max_context_length": scheduler_info.get( + "max_total_num_tokens", server_args.context_length or 8192 + ), + "vocab_size": scheduler_info.get("vocab_size", 128256), + "supports_vision": scheduler_info.get("supports_vision", False), + "model_type": scheduler_info.get("model_type", "transformer"), + "max_req_input_len": scheduler_info.get("max_req_input_len", 8192), + "eos_token_ids": scheduler_info.get("eos_token_ids", []), + "pad_token_id": scheduler_info.get("pad_token_id", 0), + "bos_token_id": scheduler_info.get("bos_token_id", 1), + } + + # Create request manager with the correct port args + request_manager = GrpcRequestManager( + server_args=server_args, + port_args=port_args, + ) + + # Create gRPC server + server = grpc.aio.server( + futures.ThreadPoolExecutor(max_workers=10), + options=[ + ("grpc.max_send_message_length", 1024 * 1024 * 256), + ("grpc.max_receive_message_length", 1024 * 1024 * 256), + ], + ) + + # Add service + servicer = SGLangSchedulerServicer( + request_manager=request_manager, + server_args=server_args, + model_info=model_info, + ) + sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server) + + # Enable reflection + SERVICE_NAMES = ( + sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name, + reflection.SERVICE_NAME, + ) + reflection.enable_server_reflection(SERVICE_NAMES, server) + + # Start server + listen_addr = f"{server_args.host}:{server_args.port}" + server.add_insecure_port(listen_addr) + + logger.info(f"Starting standalone gRPC server on {listen_addr}") + + await server.start() + + # Handle shutdown signals + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def signal_handler(): + logger.info("Received shutdown signal") + stop_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + try: + await stop_event.wait() + finally: + logger.info("Shutting down gRPC server") + await servicer.shutdown() + await server.stop(5.0) + + # Terminate scheduler processes + for i, proc in enumerate(scheduler_procs): + if proc and proc.is_alive(): + logger.info(f"Terminating scheduler process {i}...") + proc.terminate() + proc.join(timeout=5.0) + if proc.is_alive(): + logger.warning(f"Force killing scheduler process {i}...") + proc.kill() + proc.join() + + +def main(): + """Main entry point for standalone gRPC server.""" + # Fix CUDA multiprocessing issues - must be called before any CUDA operations + mp.set_start_method("spawn", force=True) + + parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server") + + # Server arguments + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") + parser.add_argument("--port", type=int, default=30000, help="gRPC server port") + + # Model arguments + parser.add_argument("--model-path", type=str, required=True, help="Model path") + 
parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path") + parser.add_argument("--context-length", type=int, help="Context length") + parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size") + + # Runtime arguments + parser.add_argument( + "--max-running-requests", type=int, default=2048, help="Max concurrent requests" + ) + parser.add_argument( + "--max-total-tokens", type=int, default=1000000, help="Max total tokens" + ) + parser.add_argument( + "--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens" + ) + parser.add_argument( + "--attention-backend", type=str, default="flashinfer", help="Attention backend" + ) + parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths") + + # Logging + parser.add_argument("--log-level", type=str, default="INFO", help="Logging level") + + args = parser.parse_args() + + # Convert to ServerArgs with gRPC host and port + server_args = ServerArgs( + model_path=args.model_path, + tokenizer_path=args.tokenizer_path or args.model_path, + context_length=args.context_length, + tp_size=args.tp_size, + dp_size=args.dp_size, + max_running_requests=args.max_running_requests, + max_total_tokens=args.max_total_tokens, + max_prefill_tokens=args.max_prefill_tokens, + attention_backend=args.attention_backend, + lora_paths=args.lora_paths.split(",") if args.lora_paths else None, + log_level=args.log_level, + # Override with gRPC server host and port + host=args.host, + port=args.port, + ) + + # Run server + asyncio.run( + serve_grpc( + server_args=server_args, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/python/sglang/srt/grpc/__init__.py b/python/sglang/srt/grpc/__init__.py new file mode 100644 index 000000000..de1d8e32a --- /dev/null +++ b/python/sglang/srt/grpc/__init__.py @@ -0,0 +1 @@ +# SGLang gRPC module diff --git a/python/sglang/srt/grpc/sglang_scheduler.proto b/python/sglang/srt/grpc/sglang_scheduler.proto new file mode 100644 index 000000000..e4c87925e --- /dev/null +++ b/python/sglang/srt/grpc/sglang_scheduler.proto @@ -0,0 +1,389 @@ +syntax = "proto3"; + +package sglang.grpc.scheduler; + +import "google/protobuf/timestamp.proto"; +import "google/protobuf/struct.proto"; + +// Service definition for SGLang scheduler communication +// This protocol bridges the Rust router and Python scheduler +service SglangScheduler { + // Submit a generation request (supports streaming) + rpc Generate(GenerateRequest) returns (stream GenerateResponse); + + // Submit an embedding request + rpc Embed(EmbedRequest) returns (EmbedResponse); + + // Health check and metrics + rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); + + // Abort a running request + rpc Abort(AbortRequest) returns (AbortResponse); + +} + +// ===================== +// Common Types +// ===================== + +// Sampling parameters matching SGLang's SamplingParams +message SamplingParams { + float temperature = 1; + float top_p = 2; + int32 top_k = 3; + float min_p = 4; + float frequency_penalty = 5; + float presence_penalty = 6; + float repetition_penalty = 7; + + int32 max_new_tokens = 8; + repeated string stop = 9; + repeated int32 stop_token_ids = 10; + bool skip_special_tokens = 11; + bool spaces_between_special_tokens = 12; + + // Structured generation + oneof constraint { + string regex = 13; + string json_schema = 14; + string ebnf_grammar = 15; + } + + // LoRA adapter + string lora_path = 16; + + // Speculative 
decoding
+  int32 n = 17;  // Number of samples
+
+  // Token healing
+  bool token_healing = 18;
+
+  // Additional parameters
+  int32 min_new_tokens = 19;
+  bool ignore_eos = 20;
+  bool no_stop_trim = 21;
+  int32 stream_interval = 22;
+  map<string, float> logit_bias = 23;
+  string structural_tag = 24;
+
+  // Custom parameters for extensibility
+  google.protobuf.Struct custom_params = 25;
+}
+
+
+// Disaggregated serving parameters
+message DisaggregatedParams {
+  string bootstrap_host = 1;
+  int32 bootstrap_port = 2;
+  int32 bootstrap_room = 3;
+}
+
+// =====================
+// Generate Request
+// =====================
+
+message GenerateRequest {
+  string request_id = 1;
+
+  // Input must be tokenized (no raw text)
+  TokenizedInput tokenized = 2;
+
+  // Multimodal inputs
+  MultimodalInputs mm_inputs = 3;
+
+  // Generation parameters
+  SamplingParams sampling_params = 4;
+
+  // Return options
+  bool return_logprob = 5;
+  int32 logprob_start_len = 6;
+  int32 top_logprobs_num = 7;
+  repeated int32 token_ids_logprob = 8;
+  bool return_hidden_states = 9;
+
+  // For disaggregated serving
+  DisaggregatedParams disaggregated_params = 10;
+
+  // Custom logit processor (serialized)
+  string custom_logit_processor = 11;
+
+  // Request metadata
+  google.protobuf.Timestamp timestamp = 12;
+  bool log_metrics = 13;
+
+  // Input embeddings (alternative to text/tokens)
+  repeated float input_embeds = 14;
+
+  // LoRA adapter ID (if pre-loaded)
+  string lora_id = 15;
+
+  // Data parallel routing
+  int32 data_parallel_rank = 16;
+
+  // For load balancing
+  int32 dp_balance_id = 17;
+}
+
+message TokenizedInput {
+  string original_text = 1;  // For reference
+  repeated int32 input_ids = 2;
+}
+
+message MultimodalInputs {
+  // Simplified multimodal handling - actual data processed by tokenizer
+  repeated string image_urls = 1;
+  repeated string video_urls = 2;
+  repeated string audio_urls = 3;
+
+  // Pre-processed multimodal features (if available)
+  google.protobuf.Struct processed_features = 4;
+
+  // Raw data for direct processing
+  repeated bytes image_data = 5;
+  repeated bytes video_data = 6;
+  repeated bytes audio_data = 7;
+
+  // Modality metadata
+  repeated string modalities = 8;
+}
+
+// =====================
+// Generate Response
+// =====================
+
+message GenerateResponse {
+  string request_id = 1;
+
+  // Response type
+  oneof response {
+    GenerateStreamChunk chunk = 2;
+    GenerateComplete complete = 3;
+    GenerateError error = 4;
+  }
+}
+
+message GenerateStreamChunk {
+  // Generated token
+  int32 token_id = 1;
+  string text = 2;
+
+  // Cumulative counts
+  int32 prompt_tokens = 3;
+  int32 completion_tokens = 4;
+  int32 cached_tokens = 5;
+
+  // Logprobs (if requested)
+  LogProbs logprobs = 6;
+
+  // Hidden states (if requested)
+  repeated float hidden_states = 7;
+
+  // Metadata
+  float generation_time = 8;  // Time to generate this token
+  int32 queue_time = 9;  // Time spent in queue
+}
+
+message GenerateComplete {
+  // Final output
+  repeated int32 output_ids = 1;
+  string output_text = 2;
+
+  // Finish reason
+  enum FinishReason {
+    // The model generated a stop sequence.
+    STOP = 0;
+    // The model reached the maximum generation length.
+    LENGTH = 1;
+    // The model generated an end-of-sequence (EOS) token.
+    EOS_TOKEN = 2;
+    // The model generated a user-provided stop string.
+    STOP_STR = 3;
+    // The request was aborted by the user or system.
+ ABORT = 4; + } + FinishReason finish_reason = 3; + + // All logprobs if requested + repeated LogProbs all_logprobs = 11; + + // All hidden states if requested + repeated HiddenStates all_hidden_states = 12; +} + +message GenerateError { + string message = 1; + string http_status_code = 2; + string details = 3; +} + +message LogProbs { + repeated float token_logprobs = 1; + repeated int32 token_ids = 2; + + // Top logprobs at each position + repeated TopLogProbs top_logprobs = 3; + + // Decoded text for tokens + repeated string token_texts = 4; +} + +message TopLogProbs { + repeated float values = 1; + repeated int32 token_ids = 2; + repeated string token_texts = 3; +} + +message HiddenStates { + repeated float values = 1; + int32 layer = 2; + int32 position = 3; +} + +// ===================== +// Embedding Request +// ===================== + +message EmbedRequest { + string request_id = 1; + + // Input must be tokenized (no raw text) + TokenizedInput tokenized = 2; + + // Multimodal inputs + MultimodalInputs mm_inputs = 4; + + // Dummy sampling params for compatibility + // EmbedRequest doesn't use sampling_params + SamplingParams sampling_params = 5; + + bool log_metrics = 6; + + // Token type IDs for models that require them + repeated int32 token_type_ids = 7; + + // Data parallel routing + int32 data_parallel_rank = 8; + + // For cross-encoder requests + bool is_cross_encoder = 9; + repeated string texts = 10; // For cross-encoder batch +} + +message EmbedResponse { + string request_id = 1; + + oneof response { + EmbedComplete complete = 2; + EmbedError error = 3; + } +} + +message EmbedComplete { + repeated float embedding = 1; + int32 prompt_tokens = 2; + int32 cached_tokens = 3; + + // Additional metadata + int32 embedding_dim = 4; + float generation_time = 5; + + // For batch embeddings + repeated Embedding batch_embeddings = 6; +} + +message Embedding { + repeated float values = 1; + int32 index = 2; +} + +message EmbedError { + string message = 1; + string code = 2; + string details = 3; +} + +// ===================== +// Management Operations +// ===================== + +message HealthCheckRequest { + // Input for health test generation (must be tokenized) + TokenizedInput tokenized = 1; +} + +message HealthCheckResponse { + bool healthy = 1; + string message = 2; +} + +message AbortRequest { + string request_id = 1; + string reason = 2; +} + +message AbortResponse { + bool success = 1; + string message = 2; +} + + +// ===================== +// Additional Operations (Future) +// ===================== + +// Load LoRA adapter +message LoadLoRARequest { + string adapter_id = 1; + string adapter_path = 2; + int32 rank = 3; +} + +message LoadLoRAResponse { + bool success = 1; + string adapter_id = 2; + string message = 3; +} + +// Unload LoRA adapter +message UnloadLoRARequest { + string adapter_id = 1; +} + +message UnloadLoRAResponse { + bool success = 1; + string message = 2; +} + +// Update weights +message UpdateWeightsRequest { + oneof source { + string disk_path = 1; + bytes tensor_data = 2; + string remote_url = 3; + } + string weight_name = 4; +} + +message UpdateWeightsResponse { + bool success = 1; + string message = 2; +} + +// Get internal state for debugging +message GetInternalStateRequest { + repeated string state_keys = 1; +} + +message GetInternalStateResponse { + google.protobuf.Struct state = 1; +} + +// Set internal state for testing +message SetInternalStateRequest { + google.protobuf.Struct state = 1; +} + +message SetInternalStateResponse { + bool success = 
1; + string message = 2; +} diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.py b/python/sglang/srt/grpc/sglang_scheduler_pb2.py new file mode 100644 index 000000000..4b288d768 --- /dev/null +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: sglang_scheduler.proto +# Protobuf Python Version: 6.31.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 31, + 1, + '', + 'sglang_scheduler.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xc7\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x16\n\x0emax_new_tokens\x18\x08 \x01(\x05\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\x05\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x11\n\tlora_path\x18\x10 \x01(\t\x12\t\n\x01n\x18\x11 \x01(\x05\x12\x15\n\rtoken_healing\x18\x12 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x13 \x01(\x05\x12\x12\n\nignore_eos\x18\x14 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x15 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x16 \x01(\x05\x12H\n\nlogit_bias\x18\x17 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12\x16\n\x0estructural_tag\x18\x18 \x01(\t\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraint\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe9\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\x05\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n 
\x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x15\n\rdp_balance_id\x18\x11 \x01(\x05\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\x05\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\xf5\x01\n\x13GenerateStreamChunk\x12\x10\n\x08token_id\x18\x01 \x01(\x05\x12\x0c\n\x04text\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12\x31\n\x08logprobs\x18\x06 \x01(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12\x15\n\rhidden_states\x18\x07 \x03(\x02\x12\x17\n\x0fgeneration_time\x18\x08 \x01(\x02\x12\x12\n\nqueue_time\x18\t \x01(\x05\"\xcd\x02\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\x05\x12\x13\n\x0boutput_text\x18\x02 \x01(\t\x12K\n\rfinish_reason\x18\x03 \x01(\x0e\x32\x34.sglang.grpc.scheduler.GenerateComplete.FinishReason\x12\x35\n\x0c\x61ll_logprobs\x18\x0b \x03(\x0b\x32\x1f.sglang.grpc.scheduler.LogProbs\x12>\n\x11\x61ll_hidden_states\x18\x0c \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\"L\n\x0c\x46inishReason\x12\x08\n\x04STOP\x10\x00\x12\n\n\x06LENGTH\x10\x01\x12\r\n\tEOS_TOKEN\x10\x02\x12\x0c\n\x08STOP_STR\x10\x03\x12\t\n\x05\x41\x42ORT\x10\x04\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"\x84\x01\n\x08LogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\x12\x13\n\x0btoken_texts\x18\x04 \x03(\t\"E\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x13\n\x0btoken_texts\x18\x03 \x03(\t\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 
\x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xbc\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12\x17\n\x0fgeneration_time\x18\x05 \x01(\x02\x12:\n\x10\x62\x61tch_embeddings\x18\x06 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sglang_scheduler_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._loaded_options = None + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_options = b'8\001' + _globals['_SAMPLINGPARAMS']._serialized_start=113 + _globals['_SAMPLINGPARAMS']._serialized_end=824 + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_start=762 + _globals['_SAMPLINGPARAMS_LOGITBIASENTRY']._serialized_end=810 + _globals['_DISAGGREGATEDPARAMS']._serialized_start=826 + _globals['_DISAGGREGATEDPARAMS']._serialized_end=919 + _globals['_GENERATEREQUEST']._serialized_start=922 + 
_globals['_GENERATEREQUEST']._serialized_end=1539 + _globals['_TOKENIZEDINPUT']._serialized_start=1541 + _globals['_TOKENIZEDINPUT']._serialized_end=1599 + _globals['_MULTIMODALINPUTS']._serialized_start=1602 + _globals['_MULTIMODALINPUTS']._serialized_end=1813 + _globals['_GENERATERESPONSE']._serialized_start=1816 + _globals['_GENERATERESPONSE']._serialized_end=2043 + _globals['_GENERATESTREAMCHUNK']._serialized_start=2046 + _globals['_GENERATESTREAMCHUNK']._serialized_end=2291 + _globals['_GENERATECOMPLETE']._serialized_start=2294 + _globals['_GENERATECOMPLETE']._serialized_end=2627 + _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_start=2551 + _globals['_GENERATECOMPLETE_FINISHREASON']._serialized_end=2627 + _globals['_GENERATEERROR']._serialized_start=2629 + _globals['_GENERATEERROR']._serialized_end=2704 + _globals['_LOGPROBS']._serialized_start=2707 + _globals['_LOGPROBS']._serialized_end=2839 + _globals['_TOPLOGPROBS']._serialized_start=2841 + _globals['_TOPLOGPROBS']._serialized_end=2910 + _globals['_HIDDENSTATES']._serialized_start=2912 + _globals['_HIDDENSTATES']._serialized_end=2975 + _globals['_EMBEDREQUEST']._serialized_start=2978 + _globals['_EMBEDREQUEST']._serialized_end=3308 + _globals['_EMBEDRESPONSE']._serialized_start=3311 + _globals['_EMBEDRESPONSE']._serialized_end=3468 + _globals['_EMBEDCOMPLETE']._serialized_start=3471 + _globals['_EMBEDCOMPLETE']._serialized_end=3659 + _globals['_EMBEDDING']._serialized_start=3661 + _globals['_EMBEDDING']._serialized_end=3703 + _globals['_EMBEDERROR']._serialized_start=3705 + _globals['_EMBEDERROR']._serialized_end=3765 + _globals['_HEALTHCHECKREQUEST']._serialized_start=3767 + _globals['_HEALTHCHECKREQUEST']._serialized_end=3845 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=3847 + _globals['_HEALTHCHECKRESPONSE']._serialized_end=3902 + _globals['_ABORTREQUEST']._serialized_start=3904 + _globals['_ABORTREQUEST']._serialized_end=3954 + _globals['_ABORTRESPONSE']._serialized_start=3956 + _globals['_ABORTRESPONSE']._serialized_end=4005 + _globals['_LOADLORAREQUEST']._serialized_start=4007 + _globals['_LOADLORAREQUEST']._serialized_end=4080 + _globals['_LOADLORARESPONSE']._serialized_start=4082 + _globals['_LOADLORARESPONSE']._serialized_end=4154 + _globals['_UNLOADLORAREQUEST']._serialized_start=4156 + _globals['_UNLOADLORAREQUEST']._serialized_end=4195 + _globals['_UNLOADLORARESPONSE']._serialized_start=4197 + _globals['_UNLOADLORARESPONSE']._serialized_end=4251 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4253 + _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4372 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4374 + _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4431 + _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4433 + _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4478 + _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4480 + _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4546 + _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4548 + _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4613 + _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4615 + _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4675 + _globals['_SGLANGSCHEDULER']._serialized_start=4678 + _globals['_SGLANGSCHEDULER']._serialized_end=5060 +# @@protoc_insertion_point(module_scope) diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi new file mode 100644 index 000000000..d9388463d --- /dev/null +++ 
b/python/sglang/srt/grpc/sglang_scheduler_pb2.pyi @@ -0,0 +1,427 @@ +import datetime + +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf import struct_pb2 as _struct_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from collections.abc import Iterable as _Iterable, Mapping as _Mapping +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class SamplingParams(_message.Message): + __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "structural_tag", "custom_params") + class LogitBiasEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: float + def __init__(self, key: _Optional[str] = ..., value: _Optional[float] = ...) -> None: ... + TEMPERATURE_FIELD_NUMBER: _ClassVar[int] + TOP_P_FIELD_NUMBER: _ClassVar[int] + TOP_K_FIELD_NUMBER: _ClassVar[int] + MIN_P_FIELD_NUMBER: _ClassVar[int] + FREQUENCY_PENALTY_FIELD_NUMBER: _ClassVar[int] + PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int] + REPETITION_PENALTY_FIELD_NUMBER: _ClassVar[int] + MAX_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int] + STOP_FIELD_NUMBER: _ClassVar[int] + STOP_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + SKIP_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int] + SPACES_BETWEEN_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int] + REGEX_FIELD_NUMBER: _ClassVar[int] + JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int] + EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int] + LORA_PATH_FIELD_NUMBER: _ClassVar[int] + N_FIELD_NUMBER: _ClassVar[int] + TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int] + MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int] + IGNORE_EOS_FIELD_NUMBER: _ClassVar[int] + NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int] + STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int] + LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int] + STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int] + CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int] + temperature: float + top_p: float + top_k: int + min_p: float + frequency_penalty: float + presence_penalty: float + repetition_penalty: float + max_new_tokens: int + stop: _containers.RepeatedScalarFieldContainer[str] + stop_token_ids: _containers.RepeatedScalarFieldContainer[int] + skip_special_tokens: bool + spaces_between_special_tokens: bool + regex: str + json_schema: str + ebnf_grammar: str + lora_path: str + n: int + token_healing: bool + min_new_tokens: int + ignore_eos: bool + no_stop_trim: bool + stream_interval: int + logit_bias: _containers.ScalarMap[str, float] + structural_tag: str + custom_params: _struct_pb2.Struct + def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., 
spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., structural_tag: _Optional[str] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class DisaggregatedParams(_message.Message): + __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room") + BOOTSTRAP_HOST_FIELD_NUMBER: _ClassVar[int] + BOOTSTRAP_PORT_FIELD_NUMBER: _ClassVar[int] + BOOTSTRAP_ROOM_FIELD_NUMBER: _ClassVar[int] + bootstrap_host: str + bootstrap_port: int + bootstrap_room: int + def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ... + +class GenerateRequest(_message.Message): + __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + TOKENIZED_FIELD_NUMBER: _ClassVar[int] + MM_INPUTS_FIELD_NUMBER: _ClassVar[int] + SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int] + RETURN_LOGPROB_FIELD_NUMBER: _ClassVar[int] + LOGPROB_START_LEN_FIELD_NUMBER: _ClassVar[int] + TOP_LOGPROBS_NUM_FIELD_NUMBER: _ClassVar[int] + TOKEN_IDS_LOGPROB_FIELD_NUMBER: _ClassVar[int] + RETURN_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] + DISAGGREGATED_PARAMS_FIELD_NUMBER: _ClassVar[int] + CUSTOM_LOGIT_PROCESSOR_FIELD_NUMBER: _ClassVar[int] + TIMESTAMP_FIELD_NUMBER: _ClassVar[int] + LOG_METRICS_FIELD_NUMBER: _ClassVar[int] + INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int] + LORA_ID_FIELD_NUMBER: _ClassVar[int] + DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] + DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int] + request_id: str + tokenized: TokenizedInput + mm_inputs: MultimodalInputs + sampling_params: SamplingParams + return_logprob: bool + logprob_start_len: int + top_logprobs_num: int + token_ids_logprob: _containers.RepeatedScalarFieldContainer[int] + return_hidden_states: bool + disaggregated_params: DisaggregatedParams + custom_logit_processor: str + timestamp: _timestamp_pb2.Timestamp + log_metrics: bool + input_embeds: _containers.RepeatedScalarFieldContainer[float] + lora_id: str + data_parallel_rank: int + dp_balance_id: int + def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ...) -> None: ... 
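
The stub above pins down the new request shape: `GenerateRequest` carries a `TokenizedInput` plus `SamplingParams`, and there is no raw-text field. A minimal construction sketch against the generated module (the request id, token ids, and sampling values are illustrative, not taken from this PR):

```python
# Sketch: building a GenerateRequest from the generated bindings in this diff.
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

request = pb2.GenerateRequest(
    request_id="req-123",  # hypothetical id
    tokenized=pb2.TokenizedInput(
        original_text="Hello world",
        input_ids=[9906, 1917],  # caller must tokenize; raw text is not accepted
    ),
    sampling_params=pb2.SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=64,
    ),
    return_logprob=True,
    top_logprobs_num=5,
)
```
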
+ +class TokenizedInput(_message.Message): + __slots__ = ("original_text", "input_ids") + ORIGINAL_TEXT_FIELD_NUMBER: _ClassVar[int] + INPUT_IDS_FIELD_NUMBER: _ClassVar[int] + original_text: str + input_ids: _containers.RepeatedScalarFieldContainer[int] + def __init__(self, original_text: _Optional[str] = ..., input_ids: _Optional[_Iterable[int]] = ...) -> None: ... + +class MultimodalInputs(_message.Message): + __slots__ = ("image_urls", "video_urls", "audio_urls", "processed_features", "image_data", "video_data", "audio_data", "modalities") + IMAGE_URLS_FIELD_NUMBER: _ClassVar[int] + VIDEO_URLS_FIELD_NUMBER: _ClassVar[int] + AUDIO_URLS_FIELD_NUMBER: _ClassVar[int] + PROCESSED_FEATURES_FIELD_NUMBER: _ClassVar[int] + IMAGE_DATA_FIELD_NUMBER: _ClassVar[int] + VIDEO_DATA_FIELD_NUMBER: _ClassVar[int] + AUDIO_DATA_FIELD_NUMBER: _ClassVar[int] + MODALITIES_FIELD_NUMBER: _ClassVar[int] + image_urls: _containers.RepeatedScalarFieldContainer[str] + video_urls: _containers.RepeatedScalarFieldContainer[str] + audio_urls: _containers.RepeatedScalarFieldContainer[str] + processed_features: _struct_pb2.Struct + image_data: _containers.RepeatedScalarFieldContainer[bytes] + video_data: _containers.RepeatedScalarFieldContainer[bytes] + audio_data: _containers.RepeatedScalarFieldContainer[bytes] + modalities: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, image_urls: _Optional[_Iterable[str]] = ..., video_urls: _Optional[_Iterable[str]] = ..., audio_urls: _Optional[_Iterable[str]] = ..., processed_features: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., image_data: _Optional[_Iterable[bytes]] = ..., video_data: _Optional[_Iterable[bytes]] = ..., audio_data: _Optional[_Iterable[bytes]] = ..., modalities: _Optional[_Iterable[str]] = ...) -> None: ... + +class GenerateResponse(_message.Message): + __slots__ = ("request_id", "chunk", "complete", "error") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + CHUNK_FIELD_NUMBER: _ClassVar[int] + COMPLETE_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + request_id: str + chunk: GenerateStreamChunk + complete: GenerateComplete + error: GenerateError + def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ... 
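
`GenerateResponse` wraps exactly one of `chunk`, `complete`, or `error` in its `response` oneof, so a client dispatches with the standard protobuf `WhichOneof` accessor. A hedged consumption sketch (`stream` stands for any iterable of `GenerateResponse` messages, e.g. the return value of the streaming `Generate` RPC):

```python
# Sketch: dispatching on the `response` oneof of GenerateResponse.
def consume(stream):
    for resp in stream:
        which = resp.WhichOneof("response")  # "chunk", "complete", or "error"
        if which == "chunk":
            print(resp.chunk.text, end="", flush=True)
        elif which == "complete":
            return list(resp.complete.output_ids)
        elif which == "error":
            raise RuntimeError(
                f"{resp.error.message} (HTTP {resp.error.http_status_code})"
            )
```
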
+ +class GenerateStreamChunk(_message.Message): + __slots__ = ("token_id", "text", "prompt_tokens", "completion_tokens", "cached_tokens", "logprobs", "hidden_states", "generation_time", "queue_time") + TOKEN_ID_FIELD_NUMBER: _ClassVar[int] + TEXT_FIELD_NUMBER: _ClassVar[int] + PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] + COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] + CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] + LOGPROBS_FIELD_NUMBER: _ClassVar[int] + HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] + GENERATION_TIME_FIELD_NUMBER: _ClassVar[int] + QUEUE_TIME_FIELD_NUMBER: _ClassVar[int] + token_id: int + text: str + prompt_tokens: int + completion_tokens: int + cached_tokens: int + logprobs: LogProbs + hidden_states: _containers.RepeatedScalarFieldContainer[float] + generation_time: float + queue_time: int + def __init__(self, token_id: _Optional[int] = ..., text: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., generation_time: _Optional[float] = ..., queue_time: _Optional[int] = ...) -> None: ... + +class GenerateComplete(_message.Message): + __slots__ = ("output_ids", "output_text", "finish_reason", "all_logprobs", "all_hidden_states") + class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + STOP: _ClassVar[GenerateComplete.FinishReason] + LENGTH: _ClassVar[GenerateComplete.FinishReason] + EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason] + STOP_STR: _ClassVar[GenerateComplete.FinishReason] + ABORT: _ClassVar[GenerateComplete.FinishReason] + STOP: GenerateComplete.FinishReason + LENGTH: GenerateComplete.FinishReason + EOS_TOKEN: GenerateComplete.FinishReason + STOP_STR: GenerateComplete.FinishReason + ABORT: GenerateComplete.FinishReason + OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int] + OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int] + FINISH_REASON_FIELD_NUMBER: _ClassVar[int] + ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] + output_ids: _containers.RepeatedScalarFieldContainer[int] + output_text: str + finish_reason: GenerateComplete.FinishReason + all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs] + all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates] + def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., output_text: _Optional[str] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ... + +class GenerateError(_message.Message): + __slots__ = ("message", "http_status_code", "details") + MESSAGE_FIELD_NUMBER: _ClassVar[int] + HTTP_STATUS_CODE_FIELD_NUMBER: _ClassVar[int] + DETAILS_FIELD_NUMBER: _ClassVar[int] + message: str + http_status_code: str + details: str + def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ... 
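
`FinishReason` is a nested proto3 enum, so its values hang off `GenerateComplete` and the `FinishReason` wrapper provides `Name`/`Value` round-tripping. A small sketch (field values invented for illustration):

```python
# Sketch: reading the nested FinishReason enum on GenerateComplete.
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

complete = pb2.GenerateComplete(
    output_ids=[11, 22],  # illustrative token ids
    finish_reason=pb2.GenerateComplete.LENGTH,  # enum values hang off the class
)
assert pb2.GenerateComplete.FinishReason.Name(complete.finish_reason) == "LENGTH"
assert pb2.GenerateComplete.FinishReason.Value("STOP") == 0
```
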
+ +class LogProbs(_message.Message): + __slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts") + TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int] + TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int] + token_logprobs: _containers.RepeatedScalarFieldContainer[float] + token_ids: _containers.RepeatedScalarFieldContainer[int] + top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs] + token_texts: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ... + +class TopLogProbs(_message.Message): + __slots__ = ("values", "token_ids", "token_texts") + VALUES_FIELD_NUMBER: _ClassVar[int] + TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] + TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int] + values: _containers.RepeatedScalarFieldContainer[float] + token_ids: _containers.RepeatedScalarFieldContainer[int] + token_texts: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ... + +class HiddenStates(_message.Message): + __slots__ = ("values", "layer", "position") + VALUES_FIELD_NUMBER: _ClassVar[int] + LAYER_FIELD_NUMBER: _ClassVar[int] + POSITION_FIELD_NUMBER: _ClassVar[int] + values: _containers.RepeatedScalarFieldContainer[float] + layer: int + position: int + def __init__(self, values: _Optional[_Iterable[float]] = ..., layer: _Optional[int] = ..., position: _Optional[int] = ...) -> None: ... + +class EmbedRequest(_message.Message): + __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "log_metrics", "token_type_ids", "data_parallel_rank", "is_cross_encoder", "texts") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + TOKENIZED_FIELD_NUMBER: _ClassVar[int] + MM_INPUTS_FIELD_NUMBER: _ClassVar[int] + SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int] + LOG_METRICS_FIELD_NUMBER: _ClassVar[int] + TOKEN_TYPE_IDS_FIELD_NUMBER: _ClassVar[int] + DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int] + IS_CROSS_ENCODER_FIELD_NUMBER: _ClassVar[int] + TEXTS_FIELD_NUMBER: _ClassVar[int] + request_id: str + tokenized: TokenizedInput + mm_inputs: MultimodalInputs + sampling_params: SamplingParams + log_metrics: bool + token_type_ids: _containers.RepeatedScalarFieldContainer[int] + data_parallel_rank: int + is_cross_encoder: bool + texts: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., log_metrics: bool = ..., token_type_ids: _Optional[_Iterable[int]] = ..., data_parallel_rank: _Optional[int] = ..., is_cross_encoder: bool = ..., texts: _Optional[_Iterable[str]] = ...) -> None: ... 
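
The `LogProbs` stub above holds token-level data as repeated fields; reading them as parallel per-position arrays (index `i` of `token_logprobs`/`token_ids` with `top_logprobs[i]` as that position's candidate list) is an assumption consistent with the schema, not something the PR states. A sketch with invented values:

```python
# Sketch: treating LogProbs as parallel per-position arrays (assumed layout).
from sglang.srt.grpc import sglang_scheduler_pb2 as pb2

lp = pb2.LogProbs(
    token_logprobs=[-0.12, -1.30],  # chosen-token logprob per position (invented)
    token_ids=[9906, 1917],
    top_logprobs=[
        pb2.TopLogProbs(values=[-0.12, -2.50], token_ids=[9906, 22691]),
        pb2.TopLogProbs(values=[-1.30, -1.90], token_ids=[1917, 1879]),
    ],
)
for i, tok in enumerate(lp.token_ids):
    print(tok, lp.token_logprobs[i], list(lp.top_logprobs[i].token_ids))
```
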
+ +class EmbedResponse(_message.Message): + __slots__ = ("request_id", "complete", "error") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + COMPLETE_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + request_id: str + complete: EmbedComplete + error: EmbedError + def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ... + +class EmbedComplete(_message.Message): + __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "generation_time", "batch_embeddings") + EMBEDDING_FIELD_NUMBER: _ClassVar[int] + PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] + CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int] + EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int] + GENERATION_TIME_FIELD_NUMBER: _ClassVar[int] + BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int] + embedding: _containers.RepeatedScalarFieldContainer[float] + prompt_tokens: int + cached_tokens: int + embedding_dim: int + generation_time: float + batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding] + def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., generation_time: _Optional[float] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ... + +class Embedding(_message.Message): + __slots__ = ("values", "index") + VALUES_FIELD_NUMBER: _ClassVar[int] + INDEX_FIELD_NUMBER: _ClassVar[int] + values: _containers.RepeatedScalarFieldContainer[float] + index: int + def __init__(self, values: _Optional[_Iterable[float]] = ..., index: _Optional[int] = ...) -> None: ... + +class EmbedError(_message.Message): + __slots__ = ("message", "code", "details") + MESSAGE_FIELD_NUMBER: _ClassVar[int] + CODE_FIELD_NUMBER: _ClassVar[int] + DETAILS_FIELD_NUMBER: _ClassVar[int] + message: str + code: str + details: str + def __init__(self, message: _Optional[str] = ..., code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ... + +class HealthCheckRequest(_message.Message): + __slots__ = ("tokenized",) + TOKENIZED_FIELD_NUMBER: _ClassVar[int] + tokenized: TokenizedInput + def __init__(self, tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...) -> None: ... + +class HealthCheckResponse(_message.Message): + __slots__ = ("healthy", "message") + HEALTHY_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + healthy: bool + message: str + def __init__(self, healthy: bool = ..., message: _Optional[str] = ...) -> None: ... + +class AbortRequest(_message.Message): + __slots__ = ("request_id", "reason") + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + REASON_FIELD_NUMBER: _ClassVar[int] + request_id: str + reason: str + def __init__(self, request_id: _Optional[str] = ..., reason: _Optional[str] = ...) -> None: ... + +class AbortResponse(_message.Message): + __slots__ = ("success", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + message: str + def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... 
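
Together with the `SglangSchedulerStub` generated further down, these messages complete a client round trip for the unary RPCs. A connection sketch; the address is a placeholder and the probe token id is illustrative:

```python
# Sketch: unary HealthCheck and Abort calls through the generated stub.
import grpc

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2
from sglang.srt.grpc import sglang_scheduler_pb2_grpc as pb2_grpc

channel = grpc.insecure_channel("localhost:30000")  # placeholder address
stub = pb2_grpc.SglangSchedulerStub(channel)

# Health checks now carry a tokenized probe instead of a metrics flag.
health = stub.HealthCheck(pb2.HealthCheckRequest(
    tokenized=pb2.TokenizedInput(original_text="Hello", input_ids=[9906])
))
print(health.healthy, health.message)

# Aborts are addressed by request id.
ack = stub.Abort(pb2.AbortRequest(request_id="req-123", reason="user cancel"))
print(ack.success, ack.message)
```
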
+ +class LoadLoRARequest(_message.Message): + __slots__ = ("adapter_id", "adapter_path", "rank") + ADAPTER_ID_FIELD_NUMBER: _ClassVar[int] + ADAPTER_PATH_FIELD_NUMBER: _ClassVar[int] + RANK_FIELD_NUMBER: _ClassVar[int] + adapter_id: str + adapter_path: str + rank: int + def __init__(self, adapter_id: _Optional[str] = ..., adapter_path: _Optional[str] = ..., rank: _Optional[int] = ...) -> None: ... + +class LoadLoRAResponse(_message.Message): + __slots__ = ("success", "adapter_id", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + ADAPTER_ID_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + adapter_id: str + message: str + def __init__(self, success: bool = ..., adapter_id: _Optional[str] = ..., message: _Optional[str] = ...) -> None: ... + +class UnloadLoRARequest(_message.Message): + __slots__ = ("adapter_id",) + ADAPTER_ID_FIELD_NUMBER: _ClassVar[int] + adapter_id: str + def __init__(self, adapter_id: _Optional[str] = ...) -> None: ... + +class UnloadLoRAResponse(_message.Message): + __slots__ = ("success", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + message: str + def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... + +class UpdateWeightsRequest(_message.Message): + __slots__ = ("disk_path", "tensor_data", "remote_url", "weight_name") + DISK_PATH_FIELD_NUMBER: _ClassVar[int] + TENSOR_DATA_FIELD_NUMBER: _ClassVar[int] + REMOTE_URL_FIELD_NUMBER: _ClassVar[int] + WEIGHT_NAME_FIELD_NUMBER: _ClassVar[int] + disk_path: str + tensor_data: bytes + remote_url: str + weight_name: str + def __init__(self, disk_path: _Optional[str] = ..., tensor_data: _Optional[bytes] = ..., remote_url: _Optional[str] = ..., weight_name: _Optional[str] = ...) -> None: ... + +class UpdateWeightsResponse(_message.Message): + __slots__ = ("success", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + message: str + def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... + +class GetInternalStateRequest(_message.Message): + __slots__ = ("state_keys",) + STATE_KEYS_FIELD_NUMBER: _ClassVar[int] + state_keys: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, state_keys: _Optional[_Iterable[str]] = ...) -> None: ... + +class GetInternalStateResponse(_message.Message): + __slots__ = ("state",) + STATE_FIELD_NUMBER: _ClassVar[int] + state: _struct_pb2.Struct + def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class SetInternalStateRequest(_message.Message): + __slots__ = ("state",) + STATE_FIELD_NUMBER: _ClassVar[int] + state: _struct_pb2.Struct + def __init__(self, state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ... + +class SetInternalStateResponse(_message.Message): + __slots__ = ("success", "message") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + success: bool + message: str + def __init__(self, success: bool = ..., message: _Optional[str] = ...) -> None: ... diff --git a/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py b/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py new file mode 100644 index 000000000..d9bdf0462 --- /dev/null +++ b/python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py @@ -0,0 +1,236 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from . import sglang_scheduler_pb2 as sglang__scheduler__pb2 + +GRPC_GENERATED_VERSION = '1.74.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in sglang_scheduler_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class SglangSchedulerStub(object): + """Service definition for SGLang scheduler communication + This protocol bridges the Rust router and Python scheduler + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Generate = channel.unary_stream( + '/sglang.grpc.scheduler.SglangScheduler/Generate', + request_serializer=sglang__scheduler__pb2.GenerateRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.GenerateResponse.FromString, + _registered_method=True) + self.Embed = channel.unary_unary( + '/sglang.grpc.scheduler.SglangScheduler/Embed', + request_serializer=sglang__scheduler__pb2.EmbedRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.EmbedResponse.FromString, + _registered_method=True) + self.HealthCheck = channel.unary_unary( + '/sglang.grpc.scheduler.SglangScheduler/HealthCheck', + request_serializer=sglang__scheduler__pb2.HealthCheckRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.HealthCheckResponse.FromString, + _registered_method=True) + self.Abort = channel.unary_unary( + '/sglang.grpc.scheduler.SglangScheduler/Abort', + request_serializer=sglang__scheduler__pb2.AbortRequest.SerializeToString, + response_deserializer=sglang__scheduler__pb2.AbortResponse.FromString, + _registered_method=True) + + +class SglangSchedulerServicer(object): + """Service definition for SGLang scheduler communication + This protocol bridges the Rust router and Python scheduler + """ + + def Generate(self, request, context): + """Submit a generation request (supports streaming) + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Embed(self, request, context): + """Submit an embedding request + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def HealthCheck(self, request, context): + """Health check and metrics + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Abort(self, request, context): + """Abort a running request + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_SglangSchedulerServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Generate': grpc.unary_stream_rpc_method_handler( + servicer.Generate, + 
request_deserializer=sglang__scheduler__pb2.GenerateRequest.FromString, + response_serializer=sglang__scheduler__pb2.GenerateResponse.SerializeToString, + ), + 'Embed': grpc.unary_unary_rpc_method_handler( + servicer.Embed, + request_deserializer=sglang__scheduler__pb2.EmbedRequest.FromString, + response_serializer=sglang__scheduler__pb2.EmbedResponse.SerializeToString, + ), + 'HealthCheck': grpc.unary_unary_rpc_method_handler( + servicer.HealthCheck, + request_deserializer=sglang__scheduler__pb2.HealthCheckRequest.FromString, + response_serializer=sglang__scheduler__pb2.HealthCheckResponse.SerializeToString, + ), + 'Abort': grpc.unary_unary_rpc_method_handler( + servicer.Abort, + request_deserializer=sglang__scheduler__pb2.AbortRequest.FromString, + response_serializer=sglang__scheduler__pb2.AbortResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('sglang.grpc.scheduler.SglangScheduler', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class SglangScheduler(object): + """Service definition for SGLang scheduler communication + This protocol bridges the Rust router and Python scheduler + """ + + @staticmethod + def Generate(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/Generate', + sglang__scheduler__pb2.GenerateRequest.SerializeToString, + sglang__scheduler__pb2.GenerateResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Embed(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/Embed', + sglang__scheduler__pb2.EmbedRequest.SerializeToString, + sglang__scheduler__pb2.EmbedResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def HealthCheck(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/HealthCheck', + sglang__scheduler__pb2.HealthCheckRequest.SerializeToString, + sglang__scheduler__pb2.HealthCheckResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Abort(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/sglang.grpc.scheduler.SglangScheduler/Abort', + sglang__scheduler__pb2.AbortRequest.SerializeToString, + sglang__scheduler__pb2.AbortResponse.FromString, + options, + 
            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+            _registered_method=True)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 68061ae97..1af94c457 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -2238,6 +2238,7 @@ class ServerArgs:
         args.pp_size = args.pipeline_parallel_size
         args.dp_size = args.data_parallel_size
         args.ep_size = args.expert_parallel_size
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
diff --git a/sgl-router/src/grpc/client.rs b/sgl-router/src/grpc/client.rs
index 8561b79db..efd141a0b 100644
--- a/sgl-router/src/grpc/client.rs
+++ b/sgl-router/src/grpc/client.rs
@@ -37,21 +37,6 @@ impl SglangSchedulerClient {
         Ok(Self { client })
     }
 
-    /// Initialize the connection
-    pub async fn initialize(
-        &mut self,
-        client_id: String,
-    ) -> Result> {
-        let request = Request::new(proto::InitializeRequest {
-            client_id,
-            client_version: "0.1.0".to_string(),
-            mode: proto::initialize_request::Mode::Regular as i32,
-        });
-
-        let response = self.client.initialize(request).await?;
-        Ok(response.into_inner())
-    }
-
     /// Submit a generation request (returns streaming response)
     pub async fn generate_stream(
         &mut self,
@@ -68,7 +53,10 @@
     ) -> Result> {
         debug!("Sending health check request");
         let request = Request::new(proto::HealthCheckRequest {
-            include_detailed_metrics: false,
+            tokenized: Some(proto::TokenizedInput {
+                original_text: "Hello".to_string(),
+                input_ids: vec![9906], // Mock token ID for "Hello"
+            }),
         });
 
         let response = self.client.health_check(request).await?;
@@ -87,21 +75,6 @@
         self.client.abort(request).await?;
         Ok(())
     }
-
-    /// Flush cache
-    pub async fn flush_cache(
-        &mut self,
-        flush_all: bool,
-        session_ids: &[String],
-    ) -> Result> {
-        let request = Request::new(proto::FlushCacheRequest {
-            flush_all,
-            session_ids: session_ids.to_vec(),
-        });
-
-        let response = self.client.flush_cache(request).await?;
-        Ok(response.into_inner())
-    }
 }
 
 #[cfg(test)]
@@ -111,14 +84,13 @@ mod tests {
     #[test]
     fn test_proto_types_compilation() {
         // Test that protobuf types can be constructed
-        let init_req = proto::InitializeRequest {
-            client_id: "test-client".to_string(),
-            client_version: "0.1.0".to_string(),
-            mode: 0,
+        let health_req = proto::HealthCheckRequest {
+            tokenized: Some(proto::TokenizedInput {
+                original_text: "test".to_string(),
+                input_ids: vec![1296],
+            }),
         };
-        assert_eq!(init_req.client_id, "test-client");
-        assert_eq!(init_req.client_version, "0.1.0");
-        assert_eq!(init_req.mode, 0);
+        assert!(health_req.tokenized.is_some());
     }
 
     #[test]
@@ -134,9 +106,10 @@
         let gen_req = proto::GenerateRequest {
             request_id: "test-req-123".to_string(),
-            input: Some(proto::generate_request::Input::Text(
-                "Hello world".to_string(),
-            )),
+            tokenized: Some(proto::TokenizedInput {
+                original_text: "Hello world".to_string(),
+                input_ids: vec![9906, 1917], // Mock token IDs for "Hello world"
+            }),
             sampling_params: Some(sampling_params),
             return_logprob: true,
             logprob_start_len: 0,
@@ -145,8 +118,8 @@
         };
 
         assert_eq!(gen_req.request_id, "test-req-123");
-        if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input {
-            assert_eq!(text, "Hello world");
+        if let Some(ref tokenized) = &gen_req.tokenized {
+            assert_eq!(tokenized.original_text, "Hello world");
         }
         assert!(gen_req.return_logprob);
         assert_eq!(gen_req.top_logprobs_num, 5);
@@ -160,9 +133,12 @@
     #[test]
     fn test_health_check_request() {
         let health_req = proto::HealthCheckRequest {
-            include_detailed_metrics: true,
+            tokenized: Some(proto::TokenizedInput {
+                original_text: "test".to_string(),
+                input_ids: vec![1296], // Mock token ID for "test"
+            }),
         };
-        assert!(health_req.include_detailed_metrics);
+        assert!(health_req.tokenized.is_some());
     }
 
     #[test]
@@ -175,17 +151,6 @@
         assert_eq!(abort_req.reason, "User canceled");
     }
 
-    #[test]
-    fn test_flush_cache_request() {
-        let flush_req = proto::FlushCacheRequest {
-            flush_all: true,
-            session_ids: vec!["session1".to_string(), "session2".to_string()],
-        };
-        assert!(flush_req.flush_all);
-        assert_eq!(flush_req.session_ids.len(), 2);
-        assert_eq!(flush_req.session_ids[0], "session1");
-    }
-
     #[test]
     fn test_sampling_params_defaults() {
         let params = proto::SamplingParams::default();
@@ -214,38 +179,29 @@
         assert_eq!(mm_inputs.modalities[0], "image");
     }
 
-    #[test]
-    fn test_session_params() {
-        let session_params = proto::SessionParams {
-            session_id: "sess-789".to_string(),
-            request_id: "req-101".to_string(),
-            offset: 100,
-            replace: true,
-            drop_previous_output: false,
-        };
-
-        assert_eq!(session_params.session_id, "sess-789");
-        assert_eq!(session_params.request_id, "req-101");
-        assert_eq!(session_params.offset, 100);
-        assert!(session_params.replace);
-        assert!(!session_params.drop_previous_output);
-    }
+    // TODO: SessionParams not in current proto - skip test
+    // #[test]
+    // fn test_session_params() { ... }
 
     #[test]
     fn test_embed_request() {
         let embed_req = proto::EmbedRequest {
             request_id: "embed-req-202".to_string(),
-            input: Some(proto::embed_request::Input::Text(
-                "This is a test sentence for embedding".to_string(),
-            )),
+            tokenized: Some(proto::TokenizedInput {
+                original_text: "This is a test sentence for embedding".to_string(),
+                input_ids: vec![2028, 374, 264, 1296, 11914, 369, 28537], // Mock token IDs
+            }),
             log_metrics: true,
             data_parallel_rank: 0,
             ..Default::default()
         };
 
         assert_eq!(embed_req.request_id, "embed-req-202");
-        if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input {
-            assert_eq!(text, "This is a test sentence for embedding");
+        if let Some(ref tokenized) = &embed_req.tokenized {
+            assert_eq!(
+                tokenized.original_text,
+                "This is a test sentence for embedding"
+            );
         }
         assert!(embed_req.log_metrics);
         assert_eq!(embed_req.data_parallel_rank, 0);
@@ -292,36 +248,7 @@
         assert_eq!(chunk.queue_time, 10);
     }
 
-    #[test]
-    fn test_model_info() {
-        let model_info = proto::ModelInfo {
-            model_name: "Meta-Llama-3-8B-Instruct".to_string(),
-            max_context_length: 8192,
-            vocab_size: 128256,
-            supports_tool_calling: true,
-            supports_vision: false,
-            special_tokens: vec![
-                "<|begin_of_text|>".to_string(),
-                "<|end_of_text|>".to_string(),
-            ],
-            model_type: "llama".to_string(),
-            num_layers: 32,
-            hidden_size: 4096,
-            num_attention_heads: 32,
-            num_key_value_heads: 8,
-            tokenizer_type: "llama".to_string(),
-            eos_token_ids: vec![128001, 128009],
-            pad_token_id: 128001,
-            bos_token_id: 128000,
-        };
-
-        assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct");
-        assert_eq!(model_info.max_context_length, 8192);
-        assert_eq!(model_info.vocab_size, 128256);
-        assert!(model_info.supports_tool_calling);
-        assert!(!model_info.supports_vision);
-        assert_eq!(model_info.special_tokens.len(), 2);
-        assert_eq!(model_info.num_layers, 32);
-        assert_eq!(model_info.eos_token_ids, vec![128001, 128009]);
-    }
+    // TODO: ModelInfo not in current proto - skip test
+    // #[test]
+    // fn test_model_info() { ... }
 }
diff --git a/sgl-router/src/proto/sglang_scheduler.proto b/sgl-router/src/proto/sglang_scheduler.proto
index 1ea2855a4..e4c87925e 100644
--- a/sgl-router/src/proto/sglang_scheduler.proto
+++ b/sgl-router/src/proto/sglang_scheduler.proto
@@ -8,9 +8,6 @@ import "google/protobuf/struct.proto";
 
 // Service definition for SGLang scheduler communication
 // This protocol bridges the Rust router and Python scheduler
 service SglangScheduler {
-  // Initialize connection and get model info
-  rpc Initialize(InitializeRequest) returns (InitializeResponse);
-
   // Submit a generation request (supports streaming)
   rpc Generate(GenerateRequest) returns (stream GenerateResponse);
@@ -23,8 +20,6 @@ service SglangScheduler {
   // Abort a running request
   rpc Abort(AbortRequest) returns (AbortResponse);
-
-  // Flush KV cache
-  rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
 }
 
 // =====================
@@ -75,14 +70,6 @@ message SamplingParams {
   google.protobuf.Struct custom_params = 25;
 }
 
-// Session parameters for continual prompting
-message SessionParams {
-  string session_id = 1;
-  string request_id = 2;
-  int32 offset = 3;
-  bool replace = 4;
-  bool drop_previous_output = 5;
-}
 
 // Disaggregated serving parameters
 message DisaggregatedParams {
@@ -91,87 +78,6 @@ message DisaggregatedParams {
   int32 bootstrap_room = 3;
 }
 
-// =====================
-// Initialize
-// =====================
-
-message InitializeRequest {
-  string client_id = 1;
-  string client_version = 2;
-
-  // Operating mode
-  enum Mode {
-    REGULAR = 0;  // Normal mode with local scheduler
-    PREFILL = 1;  // Prefill-only mode for disaggregated serving
-    DECODE = 2;   // Decode-only mode for disaggregated serving
-  }
-  Mode mode = 3;
-}
-
-message InitializeResponse {
-  bool success = 1;
-  string scheduler_version = 2;
-
-  // Model information
-  ModelInfo model_info = 3;
-
-  // Server capabilities
-  ServerCapabilities capabilities = 4;
-
-  // Error message if success is false
-  string error_message = 5;
-}
-
-message ModelInfo {
-  string model_name = 1;
-  int32 max_context_length = 2;
-  int32 vocab_size = 3;
-  bool supports_tool_calling = 4;
-  bool supports_vision = 5;
-  repeated string special_tokens = 6;
-
-  // Additional model metadata
-  string model_type = 7;
-  int32 num_layers = 8;
-  int32 hidden_size = 9;
-  int32 num_attention_heads = 10;
-  int32 num_key_value_heads = 11;
-
-  // Tokenizer info
-  string tokenizer_type = 12;
-  repeated int32 eos_token_ids = 13;
-  int32 pad_token_id = 14;
-  int32 bos_token_id = 15;
-}
-
-message ServerCapabilities {
-  bool continuous_batching = 1;
-  bool disaggregated_serving = 2;
-  bool speculative_decoding = 3;
-  int32 max_batch_size = 4;
-  int32 max_num_batched_tokens = 5;
-  int32 max_prefill_tokens = 6;
-  string attention_backend = 7; // "flashinfer", "triton", "torch"
-
-  // Additional capabilities
-  bool supports_lora = 8;
-  bool supports_grammar = 9;
-  bool supports_multimodal = 10;
-  repeated string supported_modalities = 11; // ["image", "video", "audio"]
-  bool supports_custom_logit_processor = 12;
-  bool supports_session = 13;
-
-  // Hardware info
-  int32 num_gpus = 14;
-  string gpu_type = 15;
-  int64 total_gpu_memory = 16;
-
-  // Parallelism info
-  int32 tensor_parallel_size = 17;
-  int32 pipeline_parallel_size = 18;
-  int32 data_parallel_size = 19;
-}
-
 // =====================
 // Generate Request
 // =====================
@@ -179,49 +85,43 @@ message GenerateRequest {
   string request_id = 1;
 
-  // Input can be either text or tokenized
-  oneof input {
-    string text = 2;
-    TokenizedInput tokenized = 3;
-  }
+  // Input must be tokenized (no raw text)
+  TokenizedInput tokenized = 2;
 
   // Multimodal inputs
-  MultimodalInputs mm_inputs = 4;
+  MultimodalInputs mm_inputs = 3;
 
   // Generation parameters
-  SamplingParams sampling_params = 5;
+  SamplingParams sampling_params = 4;
 
   // Return options
-  bool return_logprob = 6;
-  int32 logprob_start_len = 7;
-  int32 top_logprobs_num = 8;
-  repeated int32 token_ids_logprob = 9;
-  bool return_hidden_states = 10;
-
-  // Session management
-  SessionParams session_params = 11;
+  bool return_logprob = 5;
+  int32 logprob_start_len = 6;
+  int32 top_logprobs_num = 7;
+  repeated int32 token_ids_logprob = 8;
+  bool return_hidden_states = 9;
 
   // For disaggregated serving
-  DisaggregatedParams disaggregated_params = 12;
+  DisaggregatedParams disaggregated_params = 10;
 
   // Custom logit processor (serialized)
-  string custom_logit_processor = 13;
+  string custom_logit_processor = 11;
 
   // Request metadata
-  google.protobuf.Timestamp timestamp = 14;
-  bool log_metrics = 15;
+  google.protobuf.Timestamp timestamp = 12;
+  bool log_metrics = 13;
 
   // Input embeddings (alternative to text/tokens)
-  repeated float input_embeds = 16;
+  repeated float input_embeds = 14;
 
   // LoRA adapter ID (if pre-loaded)
-  string lora_id = 17;
+  string lora_id = 15;
 
   // Data parallel routing
-  int32 data_parallel_rank = 18;
+  int32 data_parallel_rank = 16;
 
   // For load balancing
-  int32 dp_balance_id = 19;
+  int32 dp_balance_id = 17;
 }
 
 message TokenizedInput {
@@ -303,19 +203,6 @@ message GenerateComplete {
   }
   FinishReason finish_reason = 3;
 
-  // Final counts
-  int32 prompt_tokens = 4;
-  int32 completion_tokens = 5;
-  int32 cached_tokens = 6;
-
-  // Performance metrics
-  float total_generation_time = 7;
-  float time_to_first_token = 8;
-  float tokens_per_second = 9;
-
-  // Spec decode metrics
-  int32 spec_verify_count = 10;
-
   // All logprobs if requested
   repeated LogProbs all_logprobs = 11;
@@ -359,10 +246,8 @@ message EmbedRequest {
   string request_id = 1;
 
-  oneof input {
-    string text = 2;
-    TokenizedInput tokenized = 3;
-  }
+  // Input must be tokenized (no raw text)
+  TokenizedInput tokenized = 2;
 
   // Multimodal inputs
   MultimodalInputs mm_inputs = 4;
@@ -422,39 +307,13 @@ message EmbedError {
 // =====================
 
 message HealthCheckRequest {
-  bool include_detailed_metrics = 1;
+  // Input for health test generation (must be tokenized)
+  TokenizedInput tokenized = 1;
 }
 
 message HealthCheckResponse {
   bool healthy = 1;
-
-  // Current load metrics
-  int32 num_requests_running = 2;
-  int32 num_requests_waiting = 3;
-  float gpu_cache_usage = 4;
-  float gpu_memory_usage = 5;
-
-  // KV cache metrics
-  int32 kv_cache_total_blocks = 6;
-  int32 kv_cache_used_blocks = 7;
-  float kv_cache_hit_rate = 8;
-
-  // Additional metrics
-  int32 num_grammar_queue_requests = 9;
-  float generation_throughput = 10; // tokens/sec
-  float average_queue_time = 11; // seconds
-  float average_generation_time = 12; // seconds
-
-  // System metrics
-  float cpu_usage = 13;
-  int64 memory_usage = 14;
-
-  // Disaggregation metrics
-  int32 num_prefill_requests = 15;
-  int32 num_decode_requests = 16;
-
-  // Detailed metrics (optional)
-  google.protobuf.Struct detailed_metrics = 17;
+  string message = 2;
 }
 
 message AbortRequest {
@@ -467,17 +326,6 @@ message AbortResponse {
   string message = 2;
 }
 
-message FlushCacheRequest {
-  bool flush_all = 1;
-  repeated string session_ids = 2; // Flush specific sessions
-}
-
-message FlushCacheResponse {
-  bool success = 1;
-  int32 num_entries_flushed = 2;
-  int64 memory_freed = 3; // bytes
-  string message = 4;
-}
 
 // =====================
 // Additional Operations (Future)
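
On the server side, the reworked health check reduces to a tokenized probe in and a `healthy`/`message` pair out. A hedged servicer sketch against the regenerated stubs; the readiness rule and the port are invented for illustration:

```python
# Sketch: serving the slimmed-down HealthCheck RPC.
from concurrent import futures

import grpc

from sglang.srt.grpc import sglang_scheduler_pb2 as pb2
from sglang.srt.grpc import sglang_scheduler_pb2_grpc as pb2_grpc

class Scheduler(pb2_grpc.SglangSchedulerServicer):
    def HealthCheck(self, request, context):
        # The request now carries a tokenized probe; reject an empty one.
        if not request.tokenized.input_ids:
            return pb2.HealthCheckResponse(healthy=False, message="empty probe")
        return pb2.HealthCheckResponse(healthy=True, message="ok")

server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
pb2_grpc.add_SglangSchedulerServicer_to_server(Scheduler(), server)
server.add_insecure_port("localhost:30000")  # placeholder port
server.start()
```
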