Implement Standalone gRPC Server for SGLang Python Scheduler (#10283)
This commit is contained in:
@@ -22,17 +22,19 @@ repos:
|
|||||||
rev: 5.13.2
|
rev: 5.13.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
|
exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
|
||||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
rev: v0.11.7
|
rev: v0.11.7
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--select=F401, --fixable=F401]
|
args: [--select=F401, --fixable=F401]
|
||||||
files: ^(benchmark/|docs/|examples/)
|
files: ^(benchmark/|docs/|examples/)
|
||||||
exclude: \.ipynb$
|
exclude: \.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: 24.10.0
|
rev: 24.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black-jupyter
|
- id: black-jupyter
|
||||||
|
exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
|
||||||
- repo: https://github.com/codespell-project/codespell
|
- repo: https://github.com/codespell-project/codespell
|
||||||
rev: v2.4.1
|
rev: v2.4.1
|
||||||
hooks:
|
hooks:
|
||||||
@@ -42,7 +44,11 @@ repos:
|
|||||||
exclude: |
|
exclude: |
|
||||||
(?x)^(
|
(?x)^(
|
||||||
test/srt/test_reasoning_parser\.py|
|
test/srt/test_reasoning_parser\.py|
|
||||||
docs/advanced_features/vlm_query\.ipynb
|
docs/advanced_features/vlm_query\.ipynb|
|
||||||
|
python/sglang/srt/grpc/.*_pb2\.py|
|
||||||
|
python/sglang/srt/grpc/.*_pb2_grpc\.py|
|
||||||
|
python/sglang/srt/grpc/.*_pb2\.pyi|
|
||||||
|
python/sglang/srt/grpc/.*_pb2_grpc\.pyi
|
||||||
)$
|
)$
|
||||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
rev: v18.1.8
|
rev: v18.1.8
|
||||||
|
|||||||
580
python/sglang/srt/entrypoints/grpc_request_manager.py
Normal file
580
python/sglang/srt/entrypoints/grpc_request_manager.py
Normal file
@@ -0,0 +1,580 @@
|
|||||||
|
"""
|
||||||
|
gRPC Request Manager - Orchestrates request lifecycle without tokenization.
|
||||||
|
Mimics TokenizerManager's state management and ZMQ communication patterns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
import zmq
|
||||||
|
import zmq.asyncio
|
||||||
|
|
||||||
|
from sglang.srt.managers.io_struct import (
|
||||||
|
AbortReq,
|
||||||
|
BatchEmbeddingOut,
|
||||||
|
BatchTokenIDOut,
|
||||||
|
HealthCheckOutput,
|
||||||
|
TokenizedEmbeddingReqInput,
|
||||||
|
TokenizedGenerateReqInput,
|
||||||
|
)
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.utils import get_zmq_socket, kill_process_tree
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GrpcSignalHandler:
    """Signal callbacks for the gRPC server process.

    SIGTERM only raises the manager's ``gracefully_exit`` flag; SIGQUIT kills
    the whole process tree. No crash dump is produced here.
    """

    def __init__(self, grpc_manager):
        # Manager whose ``gracefully_exit`` flag is set on SIGTERM.
        self.grpc_manager = grpc_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Log the signal and ask the manager to exit gracefully."""
        message = f"SIGTERM received. {signum=} {frame=}. Shutting down gRPC server..."
        logger.warning(message)
        self.grpc_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Log the scheduler failure, then kill this process tree."""
        logger.error(
            "Received SIGQUIT from scheduler process. Scheduler failed, shutting down gRPC server."
        )
        logger.info(
            "Note: Crash dumps are handled by the scheduler process, not the gRPC server."
        )
        # Nothing is dumped here; terminate this process and its children.
        own_pid = os.getpid()
        kill_process_tree(own_pid, include_parent=True)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
class GrpcReqState:
    """Bookkeeping record the request manager keeps for one gRPC request."""

    # Identity of the request and the gRPC call it belongs to (if any).
    request_id: str
    grpc_context: Optional[grpc.aio.ServicerContext]

    # Delivery channel: outputs are put on ``out_queue``; ``event`` is set
    # once the request is marked finished.
    out_queue: asyncio.Queue
    finished: bool
    event: asyncio.Event
    obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]

    # Timing metrics (same as TokenizerManager's ReqState).
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # Streaming progress.
    last_output_offset: int = 0
    stream_finished: bool = False

    # Accumulated output.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)

    # Session association.
    session_id: Optional[str] = None
    is_session_request: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class GrpcRequestManager:
|
||||||
|
"""
|
||||||
|
Manages gRPC request lifecycle, mimicking TokenizerManager's orchestration
|
||||||
|
behaviors without tokenization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    server_args: ServerArgs,
    port_args: PortArgs,
):
    """Create the ZMQ sockets and the in-memory request tracking state.

    Args:
        server_args: Server configuration; stored unchanged.
        port_args: Provides the IPC names for the receive and send sockets.
    """
    self.server_args, self.port_args = server_args, port_args

    # One asyncio ZMQ context shared by both sockets
    # (same pattern as TokenizerManager).
    zmq_ctx = zmq.asyncio.Context(2)

    # PULL socket: outputs coming back from the scheduler.
    self.recv_from_scheduler = get_zmq_socket(
        zmq_ctx, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
    )
    # PUSH socket: requests going out to the scheduler.
    self.send_to_scheduler = get_zmq_socket(
        zmq_ctx, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
    )

    # Request tracking and lifecycle flags.
    self.rid_to_state: Dict[str, GrpcReqState] = {}
    self.asyncio_tasks: set = set()
    self.gracefully_exit = False
    self.no_create_loop = False
    self.event_loop = None

    # Pause/resume gate checked by handle_loop.
    self.is_pause = False
    self.is_pause_cond = asyncio.Condition()

    # Counter used to mint request ids, plus the last receive timestamp.
    self.request_counter = 0
    self.request_counter_lock = asyncio.Lock()
    self.last_receive_tstamp = time.time()

    # Lightweight request records kept for debugging.
    self.crash_dump_request_list = []
    self.crash_dump_performed = False

    logger.info(
        f"GrpcRequestManager initialized with ZMQ IPC: "
        f"recv={port_args.detokenizer_ipc_name}, "
        f"send={port_args.scheduler_input_ipc_name}"
    )
|
||||||
|
|
||||||
|
async def generate_request(
    self,
    obj: TokenizedGenerateReqInput,
    request_id: Optional[str] = None,
    grpc_context: Optional[grpc.aio.ServicerContext] = None,
) -> asyncio.Queue:
    """Register a generation request and forward it to the scheduler.

    Args:
        obj: Tokenized request; its ``rid`` is overwritten with the
            effective request id.
        request_id: Id to use. When ``None`` an id of the form
            ``grpc-<counter>`` is generated.
        grpc_context: gRPC call context stored on the request state.

    Returns:
        The queue on which outputs for this request are delivered.

    Raises:
        RuntimeError: If the request could not be sent to the scheduler.
            The original exception is attached as ``__cause__``.
    """
    # Generate request ID if not provided
    if request_id is None:
        async with self.request_counter_lock:
            request_id = f"grpc-{self.request_counter}"
            self.request_counter += 1

    obj.rid = request_id

    # TODO: support log_request

    state = GrpcReqState(
        request_id=request_id,
        grpc_context=grpc_context,
        out_queue=asyncio.Queue(),
        finished=False,
        event=asyncio.Event(),
        obj=obj,
        created_time=time.time(),
    )

    # Track session if the request carries session parameters.
    session_params = getattr(obj, "session_params", None)
    if session_params:
        state.session_id = session_params.session_id
        state.is_session_request = True

    # Register before sending so outputs that arrive early find the state.
    self.rid_to_state[request_id] = state
    self.record_request_for_crash_dump(obj)

    try:
        await self._send_to_scheduler(obj)
    except Exception as e:
        # pop() instead of del: a missing entry must not replace the real
        # send error with a KeyError.
        self.rid_to_state.pop(request_id, None)
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to send request to scheduler: {e}") from e

    return state.out_queue
|
||||||
|
|
||||||
|
async def embedding_request(
    self,
    obj: TokenizedEmbeddingReqInput,
    request_id: Optional[str] = None,
) -> asyncio.Future:
    """Register an embedding request and forward it to the scheduler.

    Args:
        obj: Tokenized request; its ``rid`` is overwritten with the
            effective request id.
        request_id: Id to use. When ``None`` an id of the form
            ``grpc-embed-<counter>`` is generated.

    Returns:
        A future resolved with the first item put on the request's output
        queue after its event is set. If sending fails, the returned future
        already carries the exception.
    """
    # Generate request ID if not provided
    if request_id is None:
        async with self.request_counter_lock:
            request_id = f"grpc-embed-{self.request_counter}"
            self.request_counter += 1

    obj.rid = request_id

    state = GrpcReqState(
        request_id=request_id,
        grpc_context=None,
        out_queue=asyncio.Queue(),
        finished=False,
        event=asyncio.Event(),
        obj=obj,
        created_time=time.time(),
    )
    self.rid_to_state[request_id] = state

    future = asyncio.get_running_loop().create_future()

    try:
        await self._send_to_scheduler(obj)
    except Exception as e:
        self.rid_to_state.pop(request_id, None)
        future.set_exception(e)
        return future

    async def wait_for_result():
        try:
            await state.event.wait()
            result = await state.out_queue.get()
            # The caller may have cancelled the future in the meantime;
            # setting a result on a done future raises InvalidStateError.
            if not future.done():
                future.set_result(result)
        except Exception as e:
            if not future.done():
                future.set_exception(e)
        finally:
            self.rid_to_state.pop(request_id, None)

    # The event loop keeps only weak references to tasks, so hold a strong
    # one until the waiter finishes; the callback removes it again.
    waiter = asyncio.create_task(wait_for_result())
    self.asyncio_tasks.add(waiter)
    waiter.add_done_callback(self.asyncio_tasks.discard)
    return future
|
||||||
|
|
||||||
|
async def abort_request(self, request_id: str) -> bool:
    """Send an abort for ``request_id`` and mark its state finished.

    Returns:
        ``False`` when the request is unknown or the abort could not be
        sent, ``True`` otherwise.
    """
    if request_id not in self.rid_to_state:
        return False

    abort_req = AbortReq(rid=request_id)
    try:
        await self._send_to_scheduler(abort_req)
    except Exception as e:
        logger.error(f"Failed to send abort request: {e}")
        return False

    # Look the state up again: it may be gone after the await above.
    state = self.rid_to_state.get(request_id)
    if not state:
        return True

    state.finished = True
    state.stream_finished = True
    state.event.set()
    # Tell the consumer of the output queue that the request was aborted.
    await state.out_queue.put({"error": "Request aborted", "abort": True})
    return True
|
||||||
|
|
||||||
|
async def pause_generation(self):
    """Set the pause flag while holding the pause condition's lock."""
    pause_cond = self.is_pause_cond
    async with pause_cond:
        self.is_pause = True
        logger.info("Generation paused")
|
||||||
|
|
||||||
|
async def resume_generation(self):
    """Clear the pause flag and wake every waiter on the pause condition."""
    pause_cond = self.is_pause_cond
    async with pause_cond:
        self.is_pause = False
        # Waiters re-check ``is_pause`` once notified.
        pause_cond.notify_all()
        logger.info("Generation resumed")
|
||||||
|
|
||||||
|
async def handle_loop(self):
    """Receive scheduler outputs and dispatch them until asked to exit.

    Each received object is routed by type to the matching ``_handle_*``
    coroutine. While paused, dispatch waits on the pause condition. Errors
    are logged and the loop continues unless ``gracefully_exit`` is set.
    """
    # Checked in order; the first matching type wins.
    dispatch = (
        (BatchTokenIDOut, self._handle_batch_output),
        (BatchEmbeddingOut, self._handle_embedding_output),
        (HealthCheckOutput, self._handle_health_check_output),
    )

    while not self.gracefully_exit:
        try:
            recv_obj = await self.recv_from_scheduler.recv_pyobj()
            self.last_receive_tstamp = time.time()

            # Hold the received object until generation is resumed.
            async with self.is_pause_cond:
                await self.is_pause_cond.wait_for(lambda: not self.is_pause)

            for out_type, handler in dispatch:
                if isinstance(recv_obj, out_type):
                    await handler(recv_obj)
                    break
            else:
                logger.warning(f"Unknown output type: {type(recv_obj)}")

        except zmq.error.Again:
            # Receive timed out; stop if shutdown was requested meanwhile.
            if self.gracefully_exit:
                break
            continue
        except Exception as e:
            logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
            if self.gracefully_exit:
                break
|
||||||
|
|
||||||
|
async def _handle_batch_output(
    self, batch_out: "BatchTokenIDOut", cleanup_delay: float = 5.0
):
    """Fan one scheduler generation batch out to the per-request queues.

    For every rid in the batch that is still tracked, this updates timing
    metrics, accumulates text and token ids on the request state, and puts
    an output dict on the request's queue. When the request has a finish
    reason, the state is marked finished, its event is set, and its removal
    from ``rid_to_state`` is scheduled.

    Args:
        batch_out: Batch output received from the scheduler.
        cleanup_delay: Seconds to keep a finished request's state before
            dropping it from ``rid_to_state``.
    """
    loop = asyncio.get_running_loop()

    for i, rid in enumerate(batch_out.rids):
        state = self.rid_to_state.get(rid)
        if state is None:
            # Output for a request that is no longer tracked.
            continue

        # Update metrics
        now = time.time()
        if state.first_token_time == 0.0:
            state.first_token_time = now
        state.last_time = now

        finish_reason = batch_out.finished_reasons[i]

        # Extract output for this request
        output_data = {
            "request_id": rid,
            "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
            "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
            "finished": finish_reason is not None,
            "meta_info": {
                "prompt_tokens": (
                    batch_out.prompt_tokens[i] if batch_out.prompt_tokens else 0
                ),
                "completion_tokens": (
                    batch_out.completion_tokens[i]
                    if batch_out.completion_tokens
                    else 0
                ),
                "finish_reason": str(finish_reason) if finish_reason else None,
            },
        }

        # Add logprobs if available
        token_logprobs = batch_out.output_token_logprobs_val
        if token_logprobs and i < len(token_logprobs):
            top_logprobs = batch_out.output_top_logprobs_val
            output_data["logprobs"] = {
                "tokens": token_logprobs[i],
                "top_logprobs": (
                    top_logprobs[i]
                    if top_logprobs and i < len(top_logprobs)
                    else None
                ),
            }

        # Update state
        if output_data["text"]:
            state.text += output_data["text"][state.last_output_offset :]
            state.last_output_offset = len(output_data["text"])

        if output_data["token_ids"]:
            state.output_ids.extend(output_data["token_ids"])

        # Send to output queue
        await state.out_queue.put(output_data)

        # Handle completion
        if output_data["finished"]:
            state.finished = True
            state.finished_time = now
            state.stream_finished = True
            state.event.set()

            # call_later receives ``rid`` as an argument, so each finished
            # request removes its own entry. A closure over the loop
            # variable would see only the last rid of the batch when the
            # delay expires. The loop holds the timer handle, so no task
            # reference needs to be kept.
            loop.call_later(cleanup_delay, self.rid_to_state.pop, rid, None)
|
||||||
|
|
||||||
|
async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
    """Deliver each embedding in the batch to its request and mark it finished.

    Rids that are no longer tracked are skipped.
    """
    for idx, rid in enumerate(batch_out.rids):
        state = self.rid_to_state.get(rid)
        if state is None:
            continue

        embedding = batch_out.embeddings[idx]
        prompt_tokens = batch_out.prompt_tokens[idx] if batch_out.prompt_tokens else 0
        # NOTE(review): this reads ``finish_reason`` while the generation
        # path reads ``finished_reasons`` - confirm the attribute name on
        # BatchEmbeddingOut.
        finish_reason = batch_out.finish_reason[idx] if batch_out.finish_reason else None

        await state.out_queue.put(
            {
                "request_id": rid,
                "embedding": embedding,
                "prompt_tokens": prompt_tokens,
                "finish_reason": finish_reason,
            }
        )

        # An embedding request completes with its single result.
        state.finished = True
        state.finished_time = time.time()
        state.event.set()
|
||||||
|
|
||||||
|
async def _handle_health_check_output(self, health_out: HealthCheckOutput):
    """Deliver a health-check reply to its request and mark it finished.

    A reply for an untracked rid is logged and dropped.
    """
    rid = health_out.rid
    state = self.rid_to_state.get(rid)
    if rid not in self.rid_to_state:
        logger.warning(f"Health check output for unknown request: {rid}")
        return

    # Receiving any reply is treated as proof the scheduler is healthy.
    output_text = getattr(health_out, "output_str", "")
    finish_reason = getattr(health_out, "finish_reason", "stop")
    await state.out_queue.put(
        {
            "request_id": rid,
            "healthy": True,
            "output_text": output_text,
            "finish_reason": finish_reason,
        }
    )

    state.finished = True
    state.finished_time = time.time()
    state.event.set()
|
||||||
|
|
||||||
|
async def _send_to_scheduler(self, obj):
    """Send ``obj`` to the scheduler over the PUSH socket.

    The send socket is created from a ``zmq.asyncio`` context, whose send
    methods return an awaitable. That awaitable is awaited here so a failed
    send is raised inside this ``try`` instead of being lost on a discarded
    future. A plain synchronous socket (returning a non-awaitable) is still
    accepted.

    Raises:
        Exception: Whatever the socket raised; it is logged and re-raised.
    """
    import inspect  # local import: only this method needs it

    try:
        pending = self.send_to_scheduler.send_pyobj(obj)
        if inspect.isawaitable(pending):
            await pending
    except Exception as e:
        logger.error(f"Failed to send to scheduler: {e}")
        raise
|
||||||
|
|
||||||
|
def record_request_for_crash_dump(self, obj):
    """Append a small record about ``obj`` unless 100 are already stored.

    The record holds the current time, the object's ``rid`` (``"unknown"``
    when absent) and its type name.
    """
    records = self.crash_dump_request_list
    if len(records) >= 100:
        # List is full; later requests are not recorded.
        return

    entry = {
        "time": time.time(),
        "request_id": getattr(obj, "rid", "unknown"),
        "type": type(obj).__name__,
    }
    records.append(entry)
|
||||||
|
|
||||||
|
async def shutdown(self):
    """Stop the manager: fail pending requests, await tasks, close sockets.

    Every unfinished request receives a shutdown error on its queue and has
    its event set. The tasks in ``asyncio_tasks`` are then awaited and both
    ZMQ sockets are closed.
    """
    logger.info("Shutting down GrpcRequestManager")
    self.gracefully_exit = True

    # Iterate over a snapshot: the loop body awaits, and other coroutines
    # remove entries from rid_to_state when their requests complete.
    for state in list(self.rid_to_state.values()):
        if state.finished:
            continue
        await state.out_queue.put(
            {"error": "Server shutting down", "shutdown": True}
        )
        state.finished = True
        state.event.set()

    # NOTE(review): handle_loop re-checks gracefully_exit only after its
    # pending receive returns - confirm the receive socket has a timeout,
    # otherwise this gather can wait until the scheduler sends something.
    if self.asyncio_tasks:
        await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)

    # Close ZMQ sockets
    self.recv_from_scheduler.close()
    self.send_to_scheduler.close()

    logger.info("GrpcRequestManager shutdown complete")
|
||||||
|
|
||||||
|
def get_server_info(self) -> Dict[str, Any]:
    """Return a snapshot of tracked-request count, pause flag and last receive time."""
    info: Dict[str, Any] = {}
    info["active_requests"] = len(self.rid_to_state)
    info["paused"] = self.is_pause
    info["last_receive_time"] = self.last_receive_tstamp
    return info
|
||||||
|
|
||||||
|
def auto_create_handle_loop(self):
    """Start the handle-loop task once and install signal handlers.

    Later calls return immediately. Signal handlers are installed only when
    running on the main thread.
    """
    if self.no_create_loop:
        return
    self.no_create_loop = True

    loop = asyncio.get_event_loop()
    handle_task = loop.create_task(print_exception_wrapper(self.handle_loop))
    self.asyncio_tasks.add(handle_task)
    self.event_loop = loop

    on_main_thread = threading.current_thread() is threading.main_thread()
    if not on_main_thread:
        # Signal handlers can only be registered from the main thread
        # (CPython limitation).
        logger.warning(
            "Signal handler is not added because the grpc request manager is "
            "not in the main thread. This disables graceful shutdown of the "
            "grpc request manager when SIGTERM is received."
        )
    else:
        signal_handler = GrpcSignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
        # Replaces the SIGQUIT handler installed during the launch phase.
        loop.add_signal_handler(
            signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
        )

    # NOTE(review): the watchdog is started on both paths here, following
    # the TokenizerManager pattern the docstrings refer to - confirm.
    watchdog_task = loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
    self.asyncio_tasks.add(watchdog_task)
|
||||||
|
|
||||||
|
async def sigterm_watchdog(self):
    """Poll ``gracefully_exit`` once per second and return once it is set."""
    while True:
        if self.gracefully_exit:
            return
        await asyncio.sleep(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
async def print_exception_wrapper(func):
    """Await ``func()``; on any exception, log it and terminate the process.

    An asyncio task can fail without its exception being printed, so the
    traceback is logged here before the process tree is killed.

    Args:
        func: Zero-argument coroutine function, normally a bound method of
            ``GrpcRequestManager``.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"GrpcRequestManager hit an exception: {traceback}")

        owner = getattr(func, "__self__", None)
        if isinstance(owner, GrpcRequestManager):
            # GrpcRequestManager as defined in this file has no
            # dump_requests_before_crash method. Look it up defensively so
            # a missing hook cannot raise AttributeError here and skip the
            # process teardown below.
            dump = getattr(owner, "dump_requests_before_crash", None)
            if callable(dump):
                dump()

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
|
||||||
680
python/sglang/srt/entrypoints/grpc_server.py
Normal file
680
python/sglang/srt/entrypoints/grpc_server.py
Normal file
@@ -0,0 +1,680 @@
|
|||||||
|
"""
|
||||||
|
Standalone gRPC Server for SGLang - Fully separated from HTTP server.
|
||||||
|
Uses GrpcRequestManager for orchestration without tokenization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
from concurrent import futures
|
||||||
|
from typing import AsyncIterator, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
from grpc_reflection.v1alpha import reflection
|
||||||
|
|
||||||
|
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
|
||||||
|
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
|
||||||
|
from sglang.srt.managers.data_parallel_controller import (
|
||||||
|
run_data_parallel_controller_process,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.io_struct import (
|
||||||
|
TokenizedEmbeddingReqInput,
|
||||||
|
TokenizedGenerateReqInput,
|
||||||
|
)
|
||||||
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||||
|
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||||
|
from sglang.srt.utils import configure_logger, prepare_model_and_tokenizer
|
||||||
|
from sglang.utils import get_exception_traceback
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||||
|
|
||||||
|
|
||||||
|
def _launch_scheduler_process_only(
    server_args: ServerArgs,
    port_args: Optional[PortArgs] = None,
) -> Tuple[Dict, PortArgs, list]:
    """Start the scheduler process(es) only and wait until each reports ready.

    With ``dp_size == 1`` one scheduler process is started per (pp_rank,
    tp_rank) pair assigned to this node; otherwise a single data-parallel
    controller process is started.

    Args:
        server_args: Server configuration; its model and tokenizer paths
            are replaced by the prepared ones.
        port_args: IPC port configuration; created from ``server_args``
            when ``None``.

    Returns:
        ``(info, port_args, procs)``: the readiness payload of the first
        process, the port args in use, and the started processes.

    Raises:
        RuntimeError: If a process closes its pipe before reporting, or
            reports a status other than ``"ready"``.
    """
    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()

    # Allocate ports for inter-process communications
    if port_args is None:
        port_args = PortArgs.init_new(server_args)
        logger.info(f"{server_args=}")

    # Prepare model and tokenizer paths
    prepared_paths = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )
    server_args.model_path, server_args.tokenizer_path = prepared_paths

    scheduler_procs = []
    if server_args.dp_size != 1:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()
        scheduler_procs.append(proc)
    else:
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        scheduler_pipe_readers = []

        # TP ranks hosted by this node.
        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
        tp_rank_range = range(
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
        )

        # PP ranks hosted by this node.
        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
        pp_rank_range = range(
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
        )

        for pp_rank in pp_rank_range:
            for tp_rank in tp_rank_range:
                reader, writer = mp.Pipe(duplex=False)
                gpu_id = (
                    server_args.base_gpu_id
                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
                )
                moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
                scheduler_args = (
                    server_args,
                    port_args,
                    gpu_id,
                    tp_rank,
                    moe_ep_rank,
                    pp_rank,
                    None,
                    writer,
                    None,
                )
                proc = mp.Process(target=run_scheduler_process, args=scheduler_args)

                with memory_saver_adapter.configure_subprocess():
                    proc.start()
                scheduler_procs.append(proc)
                scheduler_pipe_readers.append(reader)

    # TODO(CatherineSue): handle cases for multi-node

    # Block until every launched process has reported through its pipe.
    scheduler_infos = []
    for i, reader in enumerate(scheduler_pipe_readers):
        try:
            data = reader.recv()
        except EOFError:
            # The pipe closed without a message.
            logger.error(
                f"Rank {i} scheduler is dead. Please check if there are relevant logs."
            )
            scheduler_procs[i].join()
            logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
            raise RuntimeError(f"Failed to initialize scheduler rank {i}")

        if data.get("status") != "ready":
            raise RuntimeError(
                f"Scheduler rank {i} initialization failed: {data.get('error', 'Unknown error')}"
            )
        scheduler_infos.append(data)

    logger.info(
        f"All {len(scheduler_procs)} scheduler process(es) initialized successfully"
    )

    # Return the first scheduler's info (they should all be the same)
    return scheduler_infos[0], port_args, scheduler_procs
|
||||||
|
|
||||||
|
|
||||||
|
class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
|
||||||
|
"""
|
||||||
|
Standalone gRPC service implementation using GrpcRequestManager.
|
||||||
|
Fully separated from HTTP server with its own process and no shared globals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    request_manager: GrpcRequestManager,
    server_args: ServerArgs,
    model_info: Dict,
):
    """Store the collaborators and start the request manager's handle loop.

    Args:
        request_manager: Manager that talks to the scheduler.
        server_args: Server configuration.
        model_info: Model information dictionary; stored unchanged.
    """
    self.request_manager = request_manager
    self.server_args = server_args
    self.model_info = model_info
    # Creation time of this servicer.
    self.start_time = time.time()

    # The manager creates its own handle-loop task (it does so only once).
    request_manager.auto_create_handle_loop()

    logger.info("Standalone gRPC scheduler service initialized")
|
||||||
|
|
||||||
|
async def Generate(
|
||||||
|
self,
|
||||||
|
request: sglang_scheduler_pb2.GenerateRequest,
|
||||||
|
context: grpc.aio.ServicerContext,
|
||||||
|
) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
|
||||||
|
"""Handle generation requests with streaming responses."""
|
||||||
|
logger.info(f"Generation request: {request.request_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert gRPC request to internal format
|
||||||
|
tokenized_req = self._convert_generate_request(request)
|
||||||
|
|
||||||
|
# Submit to request manager
|
||||||
|
output_queue = await self.request_manager.generate_request(
|
||||||
|
obj=tokenized_req,
|
||||||
|
request_id=request.request_id,
|
||||||
|
grpc_context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream outputs
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Get output with timeout
|
||||||
|
output = await asyncio.wait_for(output_queue.get(), timeout=4)
|
||||||
|
|
||||||
|
# Check for errors
|
||||||
|
if "error" in output:
|
||||||
|
yield sglang_scheduler_pb2.GenerateResponse(
|
||||||
|
request_id=request.request_id,
|
||||||
|
error=sglang_scheduler_pb2.GenerateError(
|
||||||
|
message=output["error"],
|
||||||
|
http_status_code=(
|
||||||
|
"500" if "abort" not in output else "499"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check if finished
|
||||||
|
if output.get("finished", False):
|
||||||
|
# Send completion
|
||||||
|
yield self._create_completion_response(
|
||||||
|
request.request_id, output
|
||||||
|
)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# Send chunk
|
||||||
|
yield self._create_chunk_response(request.request_id, output)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Check if context is still active
|
||||||
|
if context.cancelled():
|
||||||
|
# Abort the request
|
||||||
|
await self.request_manager.abort_request(request.request_id)
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
|
||||||
|
yield sglang_scheduler_pb2.GenerateResponse(
|
||||||
|
request_id=request.request_id,
|
||||||
|
error=sglang_scheduler_pb2.GenerateError(
|
||||||
|
message=str(e),
|
||||||
|
http_status_code="500",
|
||||||
|
details=get_exception_traceback(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def Embed(
|
||||||
|
self,
|
||||||
|
request: sglang_scheduler_pb2.EmbedRequest,
|
||||||
|
context: grpc.aio.ServicerContext,
|
||||||
|
) -> sglang_scheduler_pb2.EmbedResponse:
|
||||||
|
"""Handle embedding requests."""
|
||||||
|
logger.info(f"Embedding request: {request.request_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Convert request
|
||||||
|
tokenized_req = self._convert_embed_request(request)
|
||||||
|
|
||||||
|
# Submit to request manager
|
||||||
|
future = await self.request_manager.embedding_request(
|
||||||
|
obj=tokenized_req,
|
||||||
|
request_id=request.request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for result
|
||||||
|
result = await future
|
||||||
|
|
||||||
|
# Create response
|
||||||
|
return sglang_scheduler_pb2.EmbedResponse(
|
||||||
|
request_id=request.request_id,
|
||||||
|
complete=sglang_scheduler_pb2.EmbedComplete(
|
||||||
|
embedding=result["embedding"],
|
||||||
|
prompt_tokens=result.get("prompt_tokens", 0),
|
||||||
|
cached_tokens=0,
|
||||||
|
embedding_dim=len(result["embedding"]),
|
||||||
|
generation_time=time.time() - self.start_time,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Embed failed: {e}\n{get_exception_traceback()}")
|
||||||
|
return sglang_scheduler_pb2.EmbedResponse(
|
||||||
|
request_id=request.request_id,
|
||||||
|
error=sglang_scheduler_pb2.EmbedError(
|
||||||
|
message=str(e),
|
||||||
|
code="INTERNAL_ERROR",
|
||||||
|
details=get_exception_traceback(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def HealthCheck(
|
||||||
|
self,
|
||||||
|
request: sglang_scheduler_pb2.HealthCheckRequest,
|
||||||
|
context: grpc.aio.ServicerContext,
|
||||||
|
) -> sglang_scheduler_pb2.HealthCheckResponse:
|
||||||
|
"""Health check by generating from client input."""
|
||||||
|
try:
|
||||||
|
# Check if request manager is shutting down
|
||||||
|
if self.request_manager.gracefully_exit:
|
||||||
|
return sglang_scheduler_pb2.HealthCheckResponse(
|
||||||
|
healthy=False, message="Server shutting down"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract tokenized input from request
|
||||||
|
if not request.HasField("tokenized"):
|
||||||
|
return sglang_scheduler_pb2.HealthCheckResponse(
|
||||||
|
healthy=False, message="Tokenized input required for health check"
|
||||||
|
)
|
||||||
|
|
||||||
|
input_text = request.tokenized.original_text
|
||||||
|
input_ids = list(request.tokenized.input_ids)
|
||||||
|
|
||||||
|
# Create health check request
|
||||||
|
rid = f"HEALTH_CHECK_GRPC_{time.time()}"
|
||||||
|
|
||||||
|
health_request = TokenizedGenerateReqInput(
|
||||||
|
rid=rid,
|
||||||
|
input_text=input_text,
|
||||||
|
input_ids=input_ids,
|
||||||
|
sampling_params=SGLSamplingParams(max_new_tokens=1, temperature=0.0),
|
||||||
|
stream=False,
|
||||||
|
mm_inputs=None,
|
||||||
|
return_logprob=False,
|
||||||
|
logprob_start_len=-1,
|
||||||
|
top_logprobs_num=0,
|
||||||
|
token_ids_logprob=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Sending health check request to request manager...")
|
||||||
|
|
||||||
|
# Submit and wait for response
|
||||||
|
output_queue = await self.request_manager.generate_request(
|
||||||
|
health_request, request_id=rid
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wait for response with configurable timeout
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
if rid in self.request_manager.rid_to_state:
|
||||||
|
del self.request_manager.rid_to_state[rid]
|
||||||
|
|
||||||
|
return sglang_scheduler_pb2.HealthCheckResponse(
|
||||||
|
healthy=True, message="Health check passed"
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# Clean up on timeout
|
||||||
|
if rid in self.request_manager.rid_to_state:
|
||||||
|
del self.request_manager.rid_to_state[rid]
|
||||||
|
|
||||||
|
return sglang_scheduler_pb2.HealthCheckResponse(
|
||||||
|
healthy=False, message="Health check timeout"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Health check failed: {e}")
|
||||||
|
return sglang_scheduler_pb2.HealthCheckResponse(
|
||||||
|
healthy=False, message=f"Health check error: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def Abort(
|
||||||
|
self,
|
||||||
|
request: sglang_scheduler_pb2.AbortRequest,
|
||||||
|
context: grpc.aio.ServicerContext,
|
||||||
|
) -> sglang_scheduler_pb2.AbortResponse:
|
||||||
|
"""Abort an ongoing request."""
|
||||||
|
logger.info(f"Aborting request: {request.request_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
success = await self.request_manager.abort_request(request.request_id)
|
||||||
|
|
||||||
|
return sglang_scheduler_pb2.AbortResponse(
|
||||||
|
success=success,
|
||||||
|
message=f"Request {request.request_id} {'aborted' if success else 'not found'}",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Abort failed: {e}")
|
||||||
|
return sglang_scheduler_pb2.AbortResponse(
|
||||||
|
success=False,
|
||||||
|
message=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Helper methods for request/response conversion
|
||||||
|
|
||||||
|
def _convert_generate_request(
|
||||||
|
self, grpc_req: sglang_scheduler_pb2.GenerateRequest
|
||||||
|
) -> TokenizedGenerateReqInput:
|
||||||
|
"""Convert gRPC GenerateRequest to internal format."""
|
||||||
|
|
||||||
|
# Extract tokenized input
|
||||||
|
if not grpc_req.HasField("tokenized"):
|
||||||
|
raise ValueError("Tokenized input must be provided")
|
||||||
|
|
||||||
|
input_text = grpc_req.tokenized.original_text
|
||||||
|
input_ids = list(grpc_req.tokenized.input_ids)
|
||||||
|
|
||||||
|
# Convert sampling params
|
||||||
|
sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
|
||||||
|
|
||||||
|
# Create request
|
||||||
|
return TokenizedGenerateReqInput(
|
||||||
|
rid=grpc_req.request_id,
|
||||||
|
input_text=input_text,
|
||||||
|
input_ids=input_ids,
|
||||||
|
mm_inputs=None, # TODO: implement mm support
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
return_logprob=grpc_req.return_logprob,
|
||||||
|
logprob_start_len=grpc_req.logprob_start_len or -1,
|
||||||
|
top_logprobs_num=grpc_req.top_logprobs_num or 0,
|
||||||
|
stream=True, # Always stream for gRPC
|
||||||
|
lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
|
||||||
|
token_ids_logprob=(
|
||||||
|
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_embed_request(
|
||||||
|
self, grpc_req: sglang_scheduler_pb2.EmbedRequest
|
||||||
|
) -> TokenizedEmbeddingReqInput:
|
||||||
|
"""Convert gRPC EmbedRequest to internal format."""
|
||||||
|
|
||||||
|
# Extract tokenized input
|
||||||
|
if not grpc_req.HasField("tokenized"):
|
||||||
|
raise ValueError("Tokenized input must be provided")
|
||||||
|
|
||||||
|
input_text = grpc_req.tokenized.original_text
|
||||||
|
input_ids = list(grpc_req.tokenized.input_ids)
|
||||||
|
|
||||||
|
return TokenizedEmbeddingReqInput(
|
||||||
|
rid=grpc_req.request_id,
|
||||||
|
input_text=input_text,
|
||||||
|
input_ids=input_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_sampling_params(
|
||||||
|
self, grpc_params: sglang_scheduler_pb2.SamplingParams
|
||||||
|
) -> SGLSamplingParams:
|
||||||
|
"""Convert gRPC SamplingParams to internal format."""
|
||||||
|
|
||||||
|
# Handle constraint types
|
||||||
|
regex = None
|
||||||
|
json_schema = None
|
||||||
|
ebnf_grammar = None
|
||||||
|
|
||||||
|
if grpc_params.HasField("regex"):
|
||||||
|
regex = grpc_params.regex
|
||||||
|
elif grpc_params.HasField("json_schema"):
|
||||||
|
json_schema = grpc_params.json_schema
|
||||||
|
elif grpc_params.HasField("ebnf_grammar"):
|
||||||
|
ebnf_grammar = grpc_params.ebnf_grammar
|
||||||
|
|
||||||
|
return SGLSamplingParams(
|
||||||
|
temperature=grpc_params.temperature or 1.0,
|
||||||
|
top_p=grpc_params.top_p or 1.0,
|
||||||
|
top_k=grpc_params.top_k or -1,
|
||||||
|
min_p=grpc_params.min_p or 0.0,
|
||||||
|
frequency_penalty=grpc_params.frequency_penalty or 0.0,
|
||||||
|
presence_penalty=grpc_params.presence_penalty or 0.0,
|
||||||
|
repetition_penalty=grpc_params.repetition_penalty or 1.0,
|
||||||
|
max_new_tokens=grpc_params.max_new_tokens or 128,
|
||||||
|
min_new_tokens=grpc_params.min_new_tokens or 0,
|
||||||
|
stop=list(grpc_params.stop) if grpc_params.stop else None,
|
||||||
|
stop_token_ids=(
|
||||||
|
list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
|
||||||
|
),
|
||||||
|
skip_special_tokens=grpc_params.skip_special_tokens,
|
||||||
|
spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
|
||||||
|
regex=regex,
|
||||||
|
json_schema=json_schema,
|
||||||
|
ebnf=ebnf_grammar,
|
||||||
|
n=grpc_params.n or 1,
|
||||||
|
ignore_eos=grpc_params.ignore_eos,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_chunk_response(
|
||||||
|
self, request_id: str, output: Dict
|
||||||
|
) -> sglang_scheduler_pb2.GenerateResponse:
|
||||||
|
"""Create a streaming chunk response."""
|
||||||
|
return sglang_scheduler_pb2.GenerateResponse(
|
||||||
|
request_id=request_id,
|
||||||
|
chunk=sglang_scheduler_pb2.GenerateStreamChunk(
|
||||||
|
token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
|
||||||
|
text=output.get("text", ""),
|
||||||
|
prompt_tokens=0,
|
||||||
|
completion_tokens=len(output.get("token_ids", [])),
|
||||||
|
cached_tokens=0,
|
||||||
|
generation_time=time.time() - self.start_time,
|
||||||
|
queue_time=0.0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_completion_response(
|
||||||
|
self, request_id: str, output: Dict
|
||||||
|
) -> sglang_scheduler_pb2.GenerateResponse:
|
||||||
|
"""Create a completion response."""
|
||||||
|
|
||||||
|
# Determine finish reason
|
||||||
|
finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
|
||||||
|
meta_info = output.get("meta_info", {})
|
||||||
|
if meta_info.get("finish_reason") == "length":
|
||||||
|
finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
|
||||||
|
elif meta_info.get("finish_reason") == "eos_token":
|
||||||
|
finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
|
||||||
|
|
||||||
|
return sglang_scheduler_pb2.GenerateResponse(
|
||||||
|
request_id=request_id,
|
||||||
|
complete=sglang_scheduler_pb2.GenerateComplete(
|
||||||
|
output_ids=output.get("token_ids", []),
|
||||||
|
output_text=output.get("text", ""),
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Shutdown the service."""
|
||||||
|
logger.info("Shutting down gRPC service")
|
||||||
|
|
||||||
|
# Shutdown request manager (handles its own tasks)
|
||||||
|
await self.request_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def serve_grpc(
    server_args: ServerArgs,
    model_info: Optional[Dict] = None,
):
    """Start the standalone gRPC server with integrated scheduler.

    Launches the scheduler process(es), builds the request manager and the
    gRPC server (with reflection enabled), then serves until SIGTERM or
    SIGINT is received. Scheduler processes are terminated on every exit
    path, including failures during server setup or start-up.

    Args:
        server_args: Server configuration; ``host`` and ``port`` give the
            listen address.
        model_info: Optional model metadata. When omitted it is built from
            the scheduler's reported info and ``server_args``.
    """
    # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
    logger.info("Launching scheduler process(es)...")
    scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
        server_args=server_args,
    )

    server = None
    servicer = None
    server_started = False
    try:
        # Update model info from scheduler info
        if model_info is None:
            model_info = {
                "model_name": server_args.model_path,
                "max_context_length": scheduler_info.get(
                    "max_total_num_tokens", server_args.context_length or 8192
                ),
                "vocab_size": scheduler_info.get("vocab_size", 128256),
                "supports_vision": scheduler_info.get("supports_vision", False),
                "model_type": scheduler_info.get("model_type", "transformer"),
                "max_req_input_len": scheduler_info.get("max_req_input_len", 8192),
                "eos_token_ids": scheduler_info.get("eos_token_ids", []),
                "pad_token_id": scheduler_info.get("pad_token_id", 0),
                "bos_token_id": scheduler_info.get("bos_token_id", 1),
            }

        # Create request manager with the correct port args
        request_manager = GrpcRequestManager(
            server_args=server_args,
            port_args=port_args,
        )

        # Create gRPC server
        server = grpc.aio.server(
            futures.ThreadPoolExecutor(max_workers=10),
            options=[
                ("grpc.max_send_message_length", 1024 * 1024 * 256),
                ("grpc.max_receive_message_length", 1024 * 1024 * 256),
            ],
        )

        # Add service
        servicer = SGLangSchedulerServicer(
            request_manager=request_manager,
            server_args=server_args,
            model_info=model_info,
        )
        sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)

        # Enable reflection
        SERVICE_NAMES = (
            sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(SERVICE_NAMES, server)

        # Start server
        listen_addr = f"{server_args.host}:{server_args.port}"
        server.add_insecure_port(listen_addr)

        logger.info(f"Starting standalone gRPC server on {listen_addr}")

        await server.start()
        server_started = True

        # Handle shutdown signals
        loop = asyncio.get_running_loop()
        stop_event = asyncio.Event()

        def signal_handler():
            logger.info("Received shutdown signal")
            stop_event.set()

        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, signal_handler)

        await stop_event.wait()
    finally:
        logger.info("Shutting down gRPC server")
        try:
            # Only tear down what was actually brought up.
            if servicer is not None:
                await servicer.shutdown()
            if server_started:
                await server.stop(5.0)
        finally:
            # Terminate scheduler processes even if the graceful shutdown
            # above failed, so no child process is left behind.
            for i, proc in enumerate(scheduler_procs):
                if proc and proc.is_alive():
                    logger.info(f"Terminating scheduler process {i}...")
                    proc.terminate()
                    proc.join(timeout=5.0)
                    if proc.is_alive():
                        logger.warning(f"Force killing scheduler process {i}...")
                        proc.kill()
                        proc.join()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Main entry point for standalone gRPC server.

    Parses the command line, builds the ``ServerArgs`` for the gRPC server
    and runs ``serve_grpc`` until it returns.
    """
    # Fix CUDA multiprocessing issues - must be called before any CUDA operations
    mp.set_start_method("spawn", force=True)

    cli = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")

    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        # Server arguments
        ("--host", dict(type=str, default="0.0.0.0", help="Host to bind to")),
        ("--port", dict(type=int, default=30000, help="gRPC server port")),
        # Model arguments
        ("--model-path", dict(type=str, required=True, help="Model path")),
        ("--tokenizer-path", dict(type=str, help="Tokenizer path")),
        ("--context-length", dict(type=int, help="Context length")),
        ("--tp-size", dict(type=int, default=1, help="Tensor parallel size")),
        ("--dp-size", dict(type=int, default=1, help="Data parallel size")),
        # Runtime arguments
        (
            "--max-running-requests",
            dict(type=int, default=2048, help="Max concurrent requests"),
        ),
        (
            "--max-total-tokens",
            dict(type=int, default=1000000, help="Max total tokens"),
        ),
        (
            "--max-prefill-tokens",
            dict(type=int, default=16384, help="Max prefill tokens"),
        ),
        (
            "--attention-backend",
            dict(type=str, default="flashinfer", help="Attention backend"),
        ),
        ("--lora-paths", dict(type=str, help="LoRA adapter paths")),
        # Logging
        ("--log-level", dict(type=str, default="INFO", help="Logging level")),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    parsed = cli.parse_args()

    # --lora-paths is a comma-separated list; absent or empty means no adapters.
    if parsed.lora_paths:
        lora_paths = parsed.lora_paths.split(",")
    else:
        lora_paths = None

    # Convert to ServerArgs with gRPC host and port
    server_args = ServerArgs(
        model_path=parsed.model_path,
        tokenizer_path=parsed.tokenizer_path or parsed.model_path,
        context_length=parsed.context_length,
        tp_size=parsed.tp_size,
        dp_size=parsed.dp_size,
        max_running_requests=parsed.max_running_requests,
        max_total_tokens=parsed.max_total_tokens,
        max_prefill_tokens=parsed.max_prefill_tokens,
        attention_backend=parsed.attention_backend,
        lora_paths=lora_paths,
        log_level=parsed.log_level,
        # Override with gRPC server host and port
        host=parsed.host,
        port=parsed.port,
    )

    # Run server
    asyncio.run(serve_grpc(server_args=server_args))
|
||||||
|
|
||||||
|
|
||||||
|
# Start the gRPC server only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||||
1
python/sglang/srt/grpc/__init__.py
Normal file
1
python/sglang/srt/grpc/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# SGLang gRPC module
|
||||||
389
python/sglang/srt/grpc/sglang_scheduler.proto
Normal file
389
python/sglang/srt/grpc/sglang_scheduler.proto
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package sglang.grpc.scheduler;
|
||||||
|
|
||||||
|
import "google/protobuf/timestamp.proto";
|
||||||
|
import "google/protobuf/struct.proto";
|
||||||
|
|
||||||
|
// Service definition for SGLang scheduler communication
// This protocol bridges the Rust router and Python scheduler
service SglangScheduler {
  // Submit a generation request. Results come back as a server-side stream
  // of GenerateResponse messages.
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);

  // Submit an embedding request (single request, single response).
  rpc Embed(EmbedRequest) returns (EmbedResponse);

  // Health check. HealthCheckResponse carries only a healthy flag and a
  // message; it defines no metrics fields.
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);

  // Abort a running request, identified by its request_id.
  rpc Abort(AbortRequest) returns (AbortResponse);

}
|
||||||
|
|
||||||
|
// =====================
|
||||||
|
// Common Types
|
||||||
|
// =====================
|
||||||
|
|
||||||
|
// Sampling parameters matching SGLang's SamplingParams
//
// NOTE: these are proto3 scalar fields without explicit presence, so an
// unset field and an explicit zero/false/empty value are indistinguishable
// to the receiver.
message SamplingParams {
  float temperature = 1;
  float top_p = 2;
  int32 top_k = 3;
  float min_p = 4;
  float frequency_penalty = 5;
  float presence_penalty = 6;
  float repetition_penalty = 7;

  int32 max_new_tokens = 8;
  repeated string stop = 9;
  repeated int32 stop_token_ids = 10;
  bool skip_special_tokens = 11;
  bool spaces_between_special_tokens = 12;

  // Structured generation (at most one of these can be set)
  oneof constraint {
    string regex = 13;
    string json_schema = 14;
    string ebnf_grammar = 15;
  }

  // LoRA adapter
  string lora_path = 16;

  // Sample count
  int32 n = 17;  // Number of samples

  // Token healing
  bool token_healing = 18;

  // Additional parameters
  int32 min_new_tokens = 19;
  bool ignore_eos = 20;
  bool no_stop_trim = 21;
  int32 stream_interval = 22;
  map<string, float> logit_bias = 23;
  string structural_tag = 24;

  // Custom parameters for extensibility
  google.protobuf.Struct custom_params = 25;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Disaggregated serving parameters
// Carried by GenerateRequest.disaggregated_params.
message DisaggregatedParams {
  string bootstrap_host = 1;
  int32 bootstrap_port = 2;
  int32 bootstrap_room = 3;
}
|
||||||
|
|
||||||
|
// =====================
|
||||||
|
// Generate Request
|
||||||
|
// =====================
|
||||||
|
|
||||||
|
// Request message for the Generate RPC.
message GenerateRequest {
  // Caller-chosen identifier; echoed back in every GenerateResponse.
  string request_id = 1;

  // Input must be tokenized (no raw text)
  TokenizedInput tokenized = 2;

  // Multimodal inputs
  MultimodalInputs mm_inputs = 3;

  // Generation parameters
  SamplingParams sampling_params = 4;

  // Return options
  bool return_logprob = 5;
  int32 logprob_start_len = 6;
  int32 top_logprobs_num = 7;
  repeated int32 token_ids_logprob = 8;
  bool return_hidden_states = 9;

  // For disaggregated serving
  DisaggregatedParams disaggregated_params = 10;

  // Custom logit processor (serialized)
  string custom_logit_processor = 11;

  // Request metadata
  google.protobuf.Timestamp timestamp = 12;
  bool log_metrics = 13;

  // Input embeddings (alternative to text/tokens)
  repeated float input_embeds = 14;

  // LoRA adapter ID (if pre-loaded)
  string lora_id = 15;

  // Data parallel routing
  int32 data_parallel_rank = 16;

  // For load balancing
  int32 dp_balance_id = 17;
}
|
||||||
|
|
||||||
|
// Pre-tokenized input shared by GenerateRequest, EmbedRequest and
// HealthCheckRequest.
message TokenizedInput {
  string original_text = 1;  // For reference
  repeated int32 input_ids = 2;  // Token IDs of the input
}
|
||||||
|
|
||||||
|
// Multimodal inputs attached to a GenerateRequest or EmbedRequest.
message MultimodalInputs {
  // Simplified multimodal handling - actual data processed by tokenizer
  repeated string image_urls = 1;
  repeated string video_urls = 2;
  repeated string audio_urls = 3;

  // Pre-processed multimodal features (if available)
  google.protobuf.Struct processed_features = 4;

  // Raw data for direct processing
  repeated bytes image_data = 5;
  repeated bytes video_data = 6;
  repeated bytes audio_data = 7;

  // Modality metadata
  repeated string modalities = 8;
}
|
||||||
|
|
||||||
|
// =====================
|
||||||
|
// Generate Response
|
||||||
|
// =====================
|
||||||
|
|
||||||
|
// One element of the Generate response stream.
message GenerateResponse {
  // Identifier of the request this response belongs to.
  string request_id = 1;

  // Response type (exactly one of these is set per message)
  oneof response {
    GenerateStreamChunk chunk = 2;
    GenerateComplete complete = 3;
    GenerateError error = 4;
  }
}
|
||||||
|
|
||||||
|
// Intermediate output of a streaming generation.
message GenerateStreamChunk {
  // Generated token
  int32 token_id = 1;
  string text = 2;

  // Cumulative counts
  int32 prompt_tokens = 3;
  int32 completion_tokens = 4;
  int32 cached_tokens = 5;

  // Logprobs (if requested)
  LogProbs logprobs = 6;

  // Hidden states (if requested)
  repeated float hidden_states = 7;

  // Metadata
  // NOTE(review): units are not stated in this file (presumably seconds
  // for generation_time) - TODO confirm. queue_time is an integer field,
  // so producers must not pass a floating-point value.
  float generation_time = 8;  // Time to generate this token
  int32 queue_time = 9;       // Time spent in queue
}
|
||||||
|
|
||||||
|
// Final message of a successful generation stream.
message GenerateComplete {
  // Final output
  repeated int32 output_ids = 1;
  string output_text = 2;

  // Finish reason
  enum FinishReason {
    // The model generated a stop sequence.
    STOP = 0;
    // The model reached the maximum generation length.
    LENGTH = 1;
    // The model generated an end-of-sequence (EOS) token.
    EOS_TOKEN = 2;
    // The model generated a user-provided stop string.
    STOP_STR = 3;
    // The request was aborted by the user or system.
    ABORT = 4;
  }
  // STOP (value 0) is also what a receiver sees when this field is unset.
  FinishReason finish_reason = 3;

  // Field numbers 4-10 are not used by this message.

  // All logprobs if requested
  repeated LogProbs all_logprobs = 11;

  // All hidden states if requested
  repeated HiddenStates all_hidden_states = 12;
}
|
||||||
|
|
||||||
|
// Error payload of a GenerateResponse.
message GenerateError {
  // Human-readable error description.
  string message = 1;
  // HTTP-style status code, carried as a string rather than an integer.
  string http_status_code = 2;
  // Additional detail about the error.
  string details = 3;
}
|
||||||
|
|
||||||
|
// Log-probability information for a sequence of tokens.
// NOTE(review): the repeated fields are presumably parallel (one entry per
// position) - TODO confirm against the producer.
message LogProbs {
  repeated float token_logprobs = 1;
  repeated int32 token_ids = 2;

  // Top logprobs at each position
  repeated TopLogProbs top_logprobs = 3;

  // Decoded text for tokens
  repeated string token_texts = 4;
}
|
||||||
|
|
||||||
|
// Top candidate tokens for one position; used by LogProbs.top_logprobs.
message TopLogProbs {
  repeated float values = 1;
  repeated int32 token_ids = 2;
  repeated string token_texts = 3;
}
|
||||||
|
|
||||||
|
// Hidden-state values tagged with a layer and a position; used by
// GenerateComplete.all_hidden_states.
message HiddenStates {
  repeated float values = 1;
  int32 layer = 2;
  int32 position = 3;
}
|
||||||
|
|
||||||
|
// =====================
|
||||||
|
// Embedding Request
|
||||||
|
// =====================
|
||||||
|
|
||||||
|
// Request message for the Embed RPC.
// Field number 3 is not used by this message.
message EmbedRequest {
  // Caller-chosen identifier; echoed back in the EmbedResponse.
  string request_id = 1;

  // Input must be tokenized (no raw text)
  TokenizedInput tokenized = 2;

  // Multimodal inputs
  MultimodalInputs mm_inputs = 4;

  // Dummy sampling params for compatibility
  // EmbedRequest doesn't use sampling_params
  SamplingParams sampling_params = 5;

  bool log_metrics = 6;

  // Token type IDs for models that require them
  repeated int32 token_type_ids = 7;

  // Data parallel routing
  int32 data_parallel_rank = 8;

  // For cross-encoder requests
  bool is_cross_encoder = 9;
  repeated string texts = 10;  // For cross-encoder batch
}
|
||||||
|
|
||||||
|
// Response message for the Embed RPC.
message EmbedResponse {
  // Identifier of the request this response belongs to.
  string request_id = 1;

  // Exactly one of these is set: the result or an error.
  oneof response {
    EmbedComplete complete = 2;
    EmbedError error = 3;
  }
}
|
||||||
|
|
||||||
|
// Successful result of an embedding request.
message EmbedComplete {
  repeated float embedding = 1;
  int32 prompt_tokens = 2;
  int32 cached_tokens = 3;

  // Additional metadata
  int32 embedding_dim = 4;
  // NOTE(review): unit not stated in this file (presumably seconds) - TODO confirm.
  float generation_time = 5;

  // For batch embeddings
  repeated Embedding batch_embeddings = 6;
}
|
||||||
|
|
||||||
|
// One embedding vector of a batch; used by EmbedComplete.batch_embeddings.
message Embedding {
  repeated float values = 1;
  // NOTE(review): presumably the item's position within the batch - TODO confirm.
  int32 index = 2;
}
|
||||||
|
|
||||||
|
// Error payload of an EmbedResponse.
message EmbedError {
  // Human-readable error description.
  string message = 1;
  // Error code, carried as a string.
  string code = 2;
  // Additional detail about the error.
  string details = 3;
}
|
||||||
|
|
||||||
|
// =====================
|
||||||
|
// Management Operations
|
||||||
|
// =====================
|
||||||
|
|
||||||
|
// Request message for the HealthCheck RPC.
message HealthCheckRequest {
  // Input for health test generation (must be tokenized)
  TokenizedInput tokenized = 1;
}
|
||||||
|
|
||||||
|
// Response message for the HealthCheck RPC.
message HealthCheckResponse {
  // Whether the health check passed.
  bool healthy = 1;
  // Human-readable status or failure description.
  string message = 2;
}
|
||||||
|
|
||||||
|
// Request message for the Abort RPC.
message AbortRequest {
  // Identifier of the request to abort.
  string request_id = 1;
  // Free-form reason for the abort.
  string reason = 2;
}
|
||||||
|
|
||||||
|
// Response message for the Abort RPC.
message AbortResponse {
  // Whether the abort succeeded.
  bool success = 1;
  // Human-readable outcome description.
  string message = 2;
}
|
||||||
|
|
||||||
|
|
||||||
|
// =====================
|
||||||
|
// Additional Operations (Future)
|
||||||
|
// =====================
|
||||||
|
|
||||||
|
// Load LoRA adapter
// No RPC in the SglangScheduler service uses this message yet.
message LoadLoRARequest {
  string adapter_id = 1;
  string adapter_path = 2;
  int32 rank = 3;
}
|
||||||
|
|
||||||
|
// Result of a LoadLoRARequest.
// No RPC in the SglangScheduler service uses this message yet.
message LoadLoRAResponse {
  bool success = 1;
  string adapter_id = 2;
  string message = 3;
}
|
||||||
|
|
||||||
|
// Unload LoRA adapter
// No RPC in the SglangScheduler service uses this message yet.
message UnloadLoRARequest {
  string adapter_id = 1;
}
|
||||||
|
|
||||||
|
// Result of an UnloadLoRARequest.
// No RPC in the SglangScheduler service uses this message yet.
message UnloadLoRAResponse {
  bool success = 1;
  string message = 2;
}
|
||||||
|
|
||||||
|
// Update weights
// No RPC in the SglangScheduler service uses this message yet.
message UpdateWeightsRequest {
  // Where the new weights come from (at most one can be set).
  oneof source {
    string disk_path = 1;
    bytes tensor_data = 2;
    string remote_url = 3;
  }
  string weight_name = 4;
}
|
||||||
|
|
||||||
|
// Result of an UpdateWeightsRequest.
// No RPC in the SglangScheduler service uses this message yet.
message UpdateWeightsResponse {
  bool success = 1;
  string message = 2;
}
|
||||||
|
|
||||||
|
// Get internal state for debugging
// No RPC in the SglangScheduler service uses this message yet.
message GetInternalStateRequest {
  repeated string state_keys = 1;
}
|
||||||
|
|
||||||
|
// Result of a GetInternalStateRequest.
// No RPC in the SglangScheduler service uses this message yet.
message GetInternalStateResponse {
  google.protobuf.Struct state = 1;
}
|
||||||
|
|
||||||
|
// Set internal state for testing.
// The state is supplied as a schemaless google.protobuf.Struct.
message SetInternalStateRequest {
  google.protobuf.Struct state = 1;
}
|
||||||
|
|
||||||
|
// Outcome of a state update: a success flag plus a free-form message.
message SetInternalStateResponse {
  bool success = 1;
  string message = 2;
}
|
||||||
106
python/sglang/srt/grpc/sglang_scheduler_pb2.py
Normal file
106
python/sglang/srt/grpc/sglang_scheduler_pb2.py
Normal file
File diff suppressed because one or more lines are too long
427
python/sglang/srt/grpc/sglang_scheduler_pb2.pyi
Normal file
427
python/sglang/srt/grpc/sglang_scheduler_pb2.pyi
Normal file
@@ -0,0 +1,427 @@
|
|||||||
|
import datetime
|
||||||
|
|
||||||
|
from google.protobuf import timestamp_pb2 as _timestamp_pb2
|
||||||
|
from google.protobuf import struct_pb2 as _struct_pb2
|
||||||
|
from google.protobuf.internal import containers as _containers
|
||||||
|
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import message as _message
|
||||||
|
from collections.abc import Iterable as _Iterable, Mapping as _Mapping
|
||||||
|
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
|
||||||
|
|
||||||
|
# Module-level file descriptor exposed by the generated pb2 module.
DESCRIPTOR: _descriptor.FileDescriptor
|
||||||
|
|
||||||
|
class SamplingParams(_message.Message):
    # Type stub for the SamplingParams protobuf message.
    __slots__ = (
        "temperature",
        "top_p",
        "top_k",
        "min_p",
        "frequency_penalty",
        "presence_penalty",
        "repetition_penalty",
        "max_new_tokens",
        "stop",
        "stop_token_ids",
        "skip_special_tokens",
        "spaces_between_special_tokens",
        "regex",
        "json_schema",
        "ebnf_grammar",
        "lora_path",
        "n",
        "token_healing",
        "min_new_tokens",
        "ignore_eos",
        "no_stop_trim",
        "stream_interval",
        "logit_bias",
        "structural_tag",
        "custom_params",
    )

    class LogitBiasEntry(_message.Message):
        # Entry type backing the `logit_bias` map field (str key, float value).
        __slots__ = ("key", "value")
        KEY_FIELD_NUMBER: _ClassVar[int]
        VALUE_FIELD_NUMBER: _ClassVar[int]
        key: str
        value: float
        def __init__(
            self,
            key: _Optional[str] = ...,
            value: _Optional[float] = ...,
        ) -> None: ...

    # Field-number constants, one per message field.
    TEMPERATURE_FIELD_NUMBER: _ClassVar[int]
    TOP_P_FIELD_NUMBER: _ClassVar[int]
    TOP_K_FIELD_NUMBER: _ClassVar[int]
    MIN_P_FIELD_NUMBER: _ClassVar[int]
    FREQUENCY_PENALTY_FIELD_NUMBER: _ClassVar[int]
    PRESENCE_PENALTY_FIELD_NUMBER: _ClassVar[int]
    REPETITION_PENALTY_FIELD_NUMBER: _ClassVar[int]
    MAX_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
    STOP_FIELD_NUMBER: _ClassVar[int]
    STOP_TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
    SKIP_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int]
    SPACES_BETWEEN_SPECIAL_TOKENS_FIELD_NUMBER: _ClassVar[int]
    REGEX_FIELD_NUMBER: _ClassVar[int]
    JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
    EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
    LORA_PATH_FIELD_NUMBER: _ClassVar[int]
    N_FIELD_NUMBER: _ClassVar[int]
    TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
    MIN_NEW_TOKENS_FIELD_NUMBER: _ClassVar[int]
    IGNORE_EOS_FIELD_NUMBER: _ClassVar[int]
    NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
    STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
    LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
    STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
    CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]

    # Field attributes as seen on an instance.
    temperature: float
    top_p: float
    top_k: int
    min_p: float
    frequency_penalty: float
    presence_penalty: float
    repetition_penalty: float
    max_new_tokens: int
    stop: _containers.RepeatedScalarFieldContainer[str]
    stop_token_ids: _containers.RepeatedScalarFieldContainer[int]
    skip_special_tokens: bool
    spaces_between_special_tokens: bool
    regex: str
    json_schema: str
    ebnf_grammar: str
    lora_path: str
    n: int
    token_healing: bool
    min_new_tokens: int
    ignore_eos: bool
    no_stop_trim: bool
    stream_interval: int
    logit_bias: _containers.ScalarMap[str, float]
    structural_tag: str
    custom_params: _struct_pb2.Struct

    # Every constructor argument has a default, so all fields may be omitted.
    def __init__(
        self,
        temperature: _Optional[float] = ...,
        top_p: _Optional[float] = ...,
        top_k: _Optional[int] = ...,
        min_p: _Optional[float] = ...,
        frequency_penalty: _Optional[float] = ...,
        presence_penalty: _Optional[float] = ...,
        repetition_penalty: _Optional[float] = ...,
        max_new_tokens: _Optional[int] = ...,
        stop: _Optional[_Iterable[str]] = ...,
        stop_token_ids: _Optional[_Iterable[int]] = ...,
        skip_special_tokens: bool = ...,
        spaces_between_special_tokens: bool = ...,
        regex: _Optional[str] = ...,
        json_schema: _Optional[str] = ...,
        ebnf_grammar: _Optional[str] = ...,
        lora_path: _Optional[str] = ...,
        n: _Optional[int] = ...,
        token_healing: bool = ...,
        min_new_tokens: _Optional[int] = ...,
        ignore_eos: bool = ...,
        no_stop_trim: bool = ...,
        stream_interval: _Optional[int] = ...,
        logit_bias: _Optional[_Mapping[str, float]] = ...,
        structural_tag: _Optional[str] = ...,
        custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class DisaggregatedParams(_message.Message):
    # Type stub for the DisaggregatedParams protobuf message.
    __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")

    # Field-number constants.
    BOOTSTRAP_HOST_FIELD_NUMBER: _ClassVar[int]
    BOOTSTRAP_PORT_FIELD_NUMBER: _ClassVar[int]
    BOOTSTRAP_ROOM_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    bootstrap_host: str
    bootstrap_port: int
    bootstrap_room: int

    def __init__(
        self,
        bootstrap_host: _Optional[str] = ...,
        bootstrap_port: _Optional[int] = ...,
        bootstrap_room: _Optional[int] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class GenerateRequest(_message.Message):
    # Type stub for the GenerateRequest protobuf message.
    __slots__ = (
        "request_id",
        "tokenized",
        "mm_inputs",
        "sampling_params",
        "return_logprob",
        "logprob_start_len",
        "top_logprobs_num",
        "token_ids_logprob",
        "return_hidden_states",
        "disaggregated_params",
        "custom_logit_processor",
        "timestamp",
        "log_metrics",
        "input_embeds",
        "lora_id",
        "data_parallel_rank",
        "dp_balance_id",
    )

    # Field-number constants, one per message field.
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    TOKENIZED_FIELD_NUMBER: _ClassVar[int]
    MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
    SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
    RETURN_LOGPROB_FIELD_NUMBER: _ClassVar[int]
    LOGPROB_START_LEN_FIELD_NUMBER: _ClassVar[int]
    TOP_LOGPROBS_NUM_FIELD_NUMBER: _ClassVar[int]
    TOKEN_IDS_LOGPROB_FIELD_NUMBER: _ClassVar[int]
    RETURN_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
    DISAGGREGATED_PARAMS_FIELD_NUMBER: _ClassVar[int]
    CUSTOM_LOGIT_PROCESSOR_FIELD_NUMBER: _ClassVar[int]
    TIMESTAMP_FIELD_NUMBER: _ClassVar[int]
    LOG_METRICS_FIELD_NUMBER: _ClassVar[int]
    INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
    LORA_ID_FIELD_NUMBER: _ClassVar[int]
    DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
    DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]

    # Field attributes; message-typed fields refer to sibling stub classes.
    request_id: str
    tokenized: TokenizedInput
    mm_inputs: MultimodalInputs
    sampling_params: SamplingParams
    return_logprob: bool
    logprob_start_len: int
    top_logprobs_num: int
    token_ids_logprob: _containers.RepeatedScalarFieldContainer[int]
    return_hidden_states: bool
    disaggregated_params: DisaggregatedParams
    custom_logit_processor: str
    timestamp: _timestamp_pb2.Timestamp
    log_metrics: bool
    input_embeds: _containers.RepeatedScalarFieldContainer[float]
    lora_id: str
    data_parallel_rank: int
    dp_balance_id: int

    # Message-typed arguments also accept a mapping; `timestamp` additionally
    # accepts a datetime.datetime.
    def __init__(
        self,
        request_id: _Optional[str] = ...,
        tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...,
        mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ...,
        sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ...,
        return_logprob: bool = ...,
        logprob_start_len: _Optional[int] = ...,
        top_logprobs_num: _Optional[int] = ...,
        token_ids_logprob: _Optional[_Iterable[int]] = ...,
        return_hidden_states: bool = ...,
        disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ...,
        custom_logit_processor: _Optional[str] = ...,
        timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ...,
        log_metrics: bool = ...,
        input_embeds: _Optional[_Iterable[float]] = ...,
        lora_id: _Optional[str] = ...,
        data_parallel_rank: _Optional[int] = ...,
        dp_balance_id: _Optional[int] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class TokenizedInput(_message.Message):
    # Type stub for the TokenizedInput protobuf message.
    __slots__ = ("original_text", "input_ids")

    # Field-number constants.
    ORIGINAL_TEXT_FIELD_NUMBER: _ClassVar[int]
    INPUT_IDS_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    original_text: str
    input_ids: _containers.RepeatedScalarFieldContainer[int]

    def __init__(
        self,
        original_text: _Optional[str] = ...,
        input_ids: _Optional[_Iterable[int]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class MultimodalInputs(_message.Message):
    # Type stub for the MultimodalInputs protobuf message.
    __slots__ = (
        "image_urls",
        "video_urls",
        "audio_urls",
        "processed_features",
        "image_data",
        "video_data",
        "audio_data",
        "modalities",
    )

    # Field-number constants.
    IMAGE_URLS_FIELD_NUMBER: _ClassVar[int]
    VIDEO_URLS_FIELD_NUMBER: _ClassVar[int]
    AUDIO_URLS_FIELD_NUMBER: _ClassVar[int]
    PROCESSED_FEATURES_FIELD_NUMBER: _ClassVar[int]
    IMAGE_DATA_FIELD_NUMBER: _ClassVar[int]
    VIDEO_DATA_FIELD_NUMBER: _ClassVar[int]
    AUDIO_DATA_FIELD_NUMBER: _ClassVar[int]
    MODALITIES_FIELD_NUMBER: _ClassVar[int]

    # Field attributes: URL lists are str, raw data lists are bytes.
    image_urls: _containers.RepeatedScalarFieldContainer[str]
    video_urls: _containers.RepeatedScalarFieldContainer[str]
    audio_urls: _containers.RepeatedScalarFieldContainer[str]
    processed_features: _struct_pb2.Struct
    image_data: _containers.RepeatedScalarFieldContainer[bytes]
    video_data: _containers.RepeatedScalarFieldContainer[bytes]
    audio_data: _containers.RepeatedScalarFieldContainer[bytes]
    modalities: _containers.RepeatedScalarFieldContainer[str]

    def __init__(
        self,
        image_urls: _Optional[_Iterable[str]] = ...,
        video_urls: _Optional[_Iterable[str]] = ...,
        audio_urls: _Optional[_Iterable[str]] = ...,
        processed_features: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...,
        image_data: _Optional[_Iterable[bytes]] = ...,
        video_data: _Optional[_Iterable[bytes]] = ...,
        audio_data: _Optional[_Iterable[bytes]] = ...,
        modalities: _Optional[_Iterable[str]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class GenerateResponse(_message.Message):
    # Type stub for the GenerateResponse protobuf message.
    __slots__ = ("request_id", "chunk", "complete", "error")

    # Field-number constants.
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    CHUNK_FIELD_NUMBER: _ClassVar[int]
    COMPLETE_FIELD_NUMBER: _ClassVar[int]
    ERROR_FIELD_NUMBER: _ClassVar[int]

    # Field attributes; chunk / complete / error are message-typed.
    request_id: str
    chunk: GenerateStreamChunk
    complete: GenerateComplete
    error: GenerateError

    def __init__(
        self,
        request_id: _Optional[str] = ...,
        chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ...,
        complete: _Optional[_Union[GenerateComplete, _Mapping]] = ...,
        error: _Optional[_Union[GenerateError, _Mapping]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class GenerateStreamChunk(_message.Message):
    # Type stub for the GenerateStreamChunk protobuf message.
    __slots__ = (
        "token_id",
        "text",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
        "logprobs",
        "hidden_states",
        "generation_time",
        "queue_time",
    )

    # Field-number constants.
    TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
    TEXT_FIELD_NUMBER: _ClassVar[int]
    PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
    COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
    CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
    LOGPROBS_FIELD_NUMBER: _ClassVar[int]
    HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
    GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
    QUEUE_TIME_FIELD_NUMBER: _ClassVar[int]

    # Field attributes (note: generation_time is float, queue_time is int).
    token_id: int
    text: str
    prompt_tokens: int
    completion_tokens: int
    cached_tokens: int
    logprobs: LogProbs
    hidden_states: _containers.RepeatedScalarFieldContainer[float]
    generation_time: float
    queue_time: int

    def __init__(
        self,
        token_id: _Optional[int] = ...,
        text: _Optional[str] = ...,
        prompt_tokens: _Optional[int] = ...,
        completion_tokens: _Optional[int] = ...,
        cached_tokens: _Optional[int] = ...,
        logprobs: _Optional[_Union[LogProbs, _Mapping]] = ...,
        hidden_states: _Optional[_Iterable[float]] = ...,
        generation_time: _Optional[float] = ...,
        queue_time: _Optional[int] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class GenerateComplete(_message.Message):
    # Type stub for the GenerateComplete protobuf message.
    __slots__ = (
        "output_ids",
        "output_text",
        "finish_reason",
        "all_logprobs",
        "all_hidden_states",
    )

    class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
        # Nested enum with the five finish-reason values.
        __slots__ = ()
        STOP: _ClassVar[GenerateComplete.FinishReason]
        LENGTH: _ClassVar[GenerateComplete.FinishReason]
        EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
        STOP_STR: _ClassVar[GenerateComplete.FinishReason]
        ABORT: _ClassVar[GenerateComplete.FinishReason]

    # Enum values are also exposed directly on the message class.
    STOP: GenerateComplete.FinishReason
    LENGTH: GenerateComplete.FinishReason
    EOS_TOKEN: GenerateComplete.FinishReason
    STOP_STR: GenerateComplete.FinishReason
    ABORT: GenerateComplete.FinishReason

    # Field-number constants.
    OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
    OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
    FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
    ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
    ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    output_ids: _containers.RepeatedScalarFieldContainer[int]
    output_text: str
    finish_reason: GenerateComplete.FinishReason
    all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
    all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]

    # `finish_reason` may be passed as the enum value or as a str.
    def __init__(
        self,
        output_ids: _Optional[_Iterable[int]] = ...,
        output_text: _Optional[str] = ...,
        finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ...,
        all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ...,
        all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class GenerateError(_message.Message):
    # Type stub for the GenerateError protobuf message.
    __slots__ = ("message", "http_status_code", "details")

    # Field-number constants.
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    HTTP_STATUS_CODE_FIELD_NUMBER: _ClassVar[int]
    DETAILS_FIELD_NUMBER: _ClassVar[int]

    # Field attributes; http_status_code is typed as str, not int.
    message: str
    http_status_code: str
    details: str

    def __init__(
        self,
        message: _Optional[str] = ...,
        http_status_code: _Optional[str] = ...,
        details: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class LogProbs(_message.Message):
    # Type stub for the LogProbs protobuf message.
    __slots__ = ("token_logprobs", "token_ids", "top_logprobs", "token_texts")

    # Field-number constants.
    TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
    TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]

    # Field attributes; all four fields are repeated.
    token_logprobs: _containers.RepeatedScalarFieldContainer[float]
    token_ids: _containers.RepeatedScalarFieldContainer[int]
    top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
    token_texts: _containers.RepeatedScalarFieldContainer[str]

    def __init__(
        self,
        token_logprobs: _Optional[_Iterable[float]] = ...,
        token_ids: _Optional[_Iterable[int]] = ...,
        top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...,
        token_texts: _Optional[_Iterable[str]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class TopLogProbs(_message.Message):
    # Type stub for the TopLogProbs protobuf message.
    __slots__ = ("values", "token_ids", "token_texts")

    # Field-number constants.
    VALUES_FIELD_NUMBER: _ClassVar[int]
    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]

    # Field attributes; all three fields are repeated scalars.
    values: _containers.RepeatedScalarFieldContainer[float]
    token_ids: _containers.RepeatedScalarFieldContainer[int]
    token_texts: _containers.RepeatedScalarFieldContainer[str]

    def __init__(
        self,
        values: _Optional[_Iterable[float]] = ...,
        token_ids: _Optional[_Iterable[int]] = ...,
        token_texts: _Optional[_Iterable[str]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class HiddenStates(_message.Message):
    # Type stub for the HiddenStates protobuf message.
    __slots__ = ("values", "layer", "position")

    # Field-number constants.
    VALUES_FIELD_NUMBER: _ClassVar[int]
    LAYER_FIELD_NUMBER: _ClassVar[int]
    POSITION_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    values: _containers.RepeatedScalarFieldContainer[float]
    layer: int
    position: int

    def __init__(
        self,
        values: _Optional[_Iterable[float]] = ...,
        layer: _Optional[int] = ...,
        position: _Optional[int] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class EmbedRequest(_message.Message):
    # Type stub for the EmbedRequest protobuf message.
    __slots__ = (
        "request_id",
        "tokenized",
        "mm_inputs",
        "sampling_params",
        "log_metrics",
        "token_type_ids",
        "data_parallel_rank",
        "is_cross_encoder",
        "texts",
    )

    # Field-number constants.
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    TOKENIZED_FIELD_NUMBER: _ClassVar[int]
    MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
    SAMPLING_PARAMS_FIELD_NUMBER: _ClassVar[int]
    LOG_METRICS_FIELD_NUMBER: _ClassVar[int]
    TOKEN_TYPE_IDS_FIELD_NUMBER: _ClassVar[int]
    DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
    IS_CROSS_ENCODER_FIELD_NUMBER: _ClassVar[int]
    TEXTS_FIELD_NUMBER: _ClassVar[int]

    # Field attributes; message-typed fields refer to sibling stub classes.
    request_id: str
    tokenized: TokenizedInput
    mm_inputs: MultimodalInputs
    sampling_params: SamplingParams
    log_metrics: bool
    token_type_ids: _containers.RepeatedScalarFieldContainer[int]
    data_parallel_rank: int
    is_cross_encoder: bool
    texts: _containers.RepeatedScalarFieldContainer[str]

    def __init__(
        self,
        request_id: _Optional[str] = ...,
        tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...,
        mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ...,
        sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ...,
        log_metrics: bool = ...,
        token_type_ids: _Optional[_Iterable[int]] = ...,
        data_parallel_rank: _Optional[int] = ...,
        is_cross_encoder: bool = ...,
        texts: _Optional[_Iterable[str]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class EmbedResponse(_message.Message):
    # Type stub for the EmbedResponse protobuf message.
    __slots__ = ("request_id", "complete", "error")

    # Field-number constants.
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    COMPLETE_FIELD_NUMBER: _ClassVar[int]
    ERROR_FIELD_NUMBER: _ClassVar[int]

    # Field attributes; complete / error are message-typed.
    request_id: str
    complete: EmbedComplete
    error: EmbedError

    def __init__(
        self,
        request_id: _Optional[str] = ...,
        complete: _Optional[_Union[EmbedComplete, _Mapping]] = ...,
        error: _Optional[_Union[EmbedError, _Mapping]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class EmbedComplete(_message.Message):
    # Type stub for the EmbedComplete protobuf message.
    __slots__ = (
        "embedding",
        "prompt_tokens",
        "cached_tokens",
        "embedding_dim",
        "generation_time",
        "batch_embeddings",
    )

    # Field-number constants.
    EMBEDDING_FIELD_NUMBER: _ClassVar[int]
    PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
    CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
    EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
    GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
    BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    embedding: _containers.RepeatedScalarFieldContainer[float]
    prompt_tokens: int
    cached_tokens: int
    embedding_dim: int
    generation_time: float
    batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]

    def __init__(
        self,
        embedding: _Optional[_Iterable[float]] = ...,
        prompt_tokens: _Optional[int] = ...,
        cached_tokens: _Optional[int] = ...,
        embedding_dim: _Optional[int] = ...,
        generation_time: _Optional[float] = ...,
        batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class Embedding(_message.Message):
    # Type stub for the Embedding protobuf message.
    __slots__ = ("values", "index")

    # Field-number constants.
    VALUES_FIELD_NUMBER: _ClassVar[int]
    INDEX_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    values: _containers.RepeatedScalarFieldContainer[float]
    index: int

    def __init__(
        self,
        values: _Optional[_Iterable[float]] = ...,
        index: _Optional[int] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class EmbedError(_message.Message):
    # Type stub for the EmbedError protobuf message.
    __slots__ = ("message", "code", "details")

    # Field-number constants.
    MESSAGE_FIELD_NUMBER: _ClassVar[int]
    CODE_FIELD_NUMBER: _ClassVar[int]
    DETAILS_FIELD_NUMBER: _ClassVar[int]

    # Field attributes; all three are str.
    message: str
    code: str
    details: str

    def __init__(
        self,
        message: _Optional[str] = ...,
        code: _Optional[str] = ...,
        details: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class HealthCheckRequest(_message.Message):
    # Type stub for the HealthCheckRequest protobuf message.
    __slots__ = ("tokenized",)

    # Field-number constant.
    TOKENIZED_FIELD_NUMBER: _ClassVar[int]

    # Field attribute.
    tokenized: TokenizedInput

    def __init__(
        self,
        tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class HealthCheckResponse(_message.Message):
    # Type stub for the HealthCheckResponse protobuf message.
    __slots__ = ("healthy", "message")

    # Field-number constants.
    HEALTHY_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    healthy: bool
    message: str

    def __init__(
        self,
        healthy: bool = ...,
        message: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class AbortRequest(_message.Message):
    # Type stub for the AbortRequest protobuf message.
    __slots__ = ("request_id", "reason")

    # Field-number constants.
    REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
    REASON_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    request_id: str
    reason: str

    def __init__(
        self,
        request_id: _Optional[str] = ...,
        reason: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class AbortResponse(_message.Message):
    # Type stub for the AbortResponse protobuf message.
    __slots__ = ("success", "message")

    # Field-number constants.
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    success: bool
    message: str

    def __init__(
        self,
        success: bool = ...,
        message: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class LoadLoRARequest(_message.Message):
    # Type stub for the LoadLoRARequest protobuf message.
    __slots__ = ("adapter_id", "adapter_path", "rank")

    # Field-number constants.
    ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
    ADAPTER_PATH_FIELD_NUMBER: _ClassVar[int]
    RANK_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    adapter_id: str
    adapter_path: str
    rank: int

    def __init__(
        self,
        adapter_id: _Optional[str] = ...,
        adapter_path: _Optional[str] = ...,
        rank: _Optional[int] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class LoadLoRAResponse(_message.Message):
    # Type stub for the LoadLoRAResponse protobuf message.
    __slots__ = ("success", "adapter_id", "message")

    # Field-number constants.
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    success: bool
    adapter_id: str
    message: str

    def __init__(
        self,
        success: bool = ...,
        adapter_id: _Optional[str] = ...,
        message: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class UnloadLoRARequest(_message.Message):
    # Type stub for the UnloadLoRARequest protobuf message.
    __slots__ = ("adapter_id",)

    # Field-number constant.
    ADAPTER_ID_FIELD_NUMBER: _ClassVar[int]

    # Field attribute.
    adapter_id: str

    def __init__(
        self,
        adapter_id: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class UnloadLoRAResponse(_message.Message):
    # Type stub for the UnloadLoRAResponse protobuf message.
    __slots__ = ("success", "message")

    # Field-number constants.
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    success: bool
    message: str

    def __init__(
        self,
        success: bool = ...,
        message: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class UpdateWeightsRequest(_message.Message):
    # Type stub for the UpdateWeightsRequest protobuf message.
    __slots__ = ("disk_path", "tensor_data", "remote_url", "weight_name")

    # Field-number constants.
    DISK_PATH_FIELD_NUMBER: _ClassVar[int]
    TENSOR_DATA_FIELD_NUMBER: _ClassVar[int]
    REMOTE_URL_FIELD_NUMBER: _ClassVar[int]
    WEIGHT_NAME_FIELD_NUMBER: _ClassVar[int]

    # Field attributes; tensor_data is the only bytes-typed field.
    disk_path: str
    tensor_data: bytes
    remote_url: str
    weight_name: str

    def __init__(
        self,
        disk_path: _Optional[str] = ...,
        tensor_data: _Optional[bytes] = ...,
        remote_url: _Optional[str] = ...,
        weight_name: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class UpdateWeightsResponse(_message.Message):
    # Type stub for the UpdateWeightsResponse protobuf message.
    __slots__ = ("success", "message")

    # Field-number constants.
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    success: bool
    message: str

    def __init__(
        self,
        success: bool = ...,
        message: _Optional[str] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class GetInternalStateRequest(_message.Message):
    # Type stub for the GetInternalStateRequest protobuf message.
    __slots__ = ("state_keys",)

    # Field-number constant.
    STATE_KEYS_FIELD_NUMBER: _ClassVar[int]

    # Field attribute.
    state_keys: _containers.RepeatedScalarFieldContainer[str]

    def __init__(
        self,
        state_keys: _Optional[_Iterable[str]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class GetInternalStateResponse(_message.Message):
    # Type stub for the GetInternalStateResponse protobuf message.
    __slots__ = ("state",)

    # Field-number constant.
    STATE_FIELD_NUMBER: _ClassVar[int]

    # Field attribute.
    state: _struct_pb2.Struct

    # `state` may be given as a Struct or as a plain mapping.
    def __init__(
        self,
        state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class SetInternalStateRequest(_message.Message):
    # Type stub for the SetInternalStateRequest protobuf message.
    __slots__ = ("state",)

    # Field-number constant.
    STATE_FIELD_NUMBER: _ClassVar[int]

    # Field attribute.
    state: _struct_pb2.Struct

    # `state` may be given as a Struct or as a plain mapping.
    def __init__(
        self,
        state: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...,
    ) -> None: ...
|
||||||
|
|
||||||
|
class SetInternalStateResponse(_message.Message):
    # Type stub for the SetInternalStateResponse protobuf message.
    __slots__ = ("success", "message")

    # Field-number constants.
    SUCCESS_FIELD_NUMBER: _ClassVar[int]
    MESSAGE_FIELD_NUMBER: _ClassVar[int]

    # Field attributes.
    success: bool
    message: str

    def __init__(
        self,
        success: bool = ...,
        message: _Optional[str] = ...,
    ) -> None: ...
|
||||||
236
python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
Normal file
236
python/sglang/srt/grpc/sglang_scheduler_pb2_grpc.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||||
|
"""Client and server classes corresponding to protobuf-defined services."""
|
||||||
|
import grpc
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from . import sglang_scheduler_pb2 as sglang__scheduler__pb2
|
||||||
|
|
||||||
|
# Import-time guard: refuse to load when the installed grpcio is older than
# the grpcio-tools release that generated this module.
GRPC_GENERATED_VERSION = '1.74.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # The comparison helper cannot be imported, so treat the installed
    # grpcio as unsupported.
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        ' but the generated code in sglang_scheduler_pb2_grpc.py depends on'
        f' grpcio>={GRPC_GENERATED_VERSION}.'
        f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )
|
||||||
|
|
||||||
|
|
||||||
|
class SglangSchedulerStub(object):
    """Client-side stub for the SglangScheduler gRPC service.

    The service carries the protocol between the Rust router and the
    Python scheduler.
    """

    def __init__(self, channel):
        """Bind one callable per RPC onto the given channel.

        Args:
            channel: A grpc.Channel.
        """
        pb2 = sglang__scheduler__pb2

        # Generate is the only server-streaming RPC; the rest are unary-unary.
        self.Generate = channel.unary_stream(
            '/sglang.grpc.scheduler.SglangScheduler/Generate',
            request_serializer=pb2.GenerateRequest.SerializeToString,
            response_deserializer=pb2.GenerateResponse.FromString,
            _registered_method=True,
        )
        self.Embed = channel.unary_unary(
            '/sglang.grpc.scheduler.SglangScheduler/Embed',
            request_serializer=pb2.EmbedRequest.SerializeToString,
            response_deserializer=pb2.EmbedResponse.FromString,
            _registered_method=True,
        )
        self.HealthCheck = channel.unary_unary(
            '/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
            request_serializer=pb2.HealthCheckRequest.SerializeToString,
            response_deserializer=pb2.HealthCheckResponse.FromString,
            _registered_method=True,
        )
        self.Abort = channel.unary_unary(
            '/sglang.grpc.scheduler.SglangScheduler/Abort',
            request_serializer=pb2.AbortRequest.SerializeToString,
            response_deserializer=pb2.AbortResponse.FromString,
            _registered_method=True,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class SglangSchedulerServicer(object):
    """Server-side base class for the SglangScheduler gRPC service.

    The service carries the protocol between the Rust router and the
    Python scheduler. Every method here reports UNIMPLEMENTED on the
    context and raises NotImplementedError.
    """

    def Generate(self, request, context):
        """Submit a generation request (supports streaming)."""
        details = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(details)
        raise NotImplementedError(details)

    def Embed(self, request, context):
        """Submit an embedding request."""
        details = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(details)
        raise NotImplementedError(details)

    def HealthCheck(self, request, context):
        """Health check and metrics."""
        details = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(details)
        raise NotImplementedError(details)

    def Abort(self, request, context):
        """Abort a running request."""
        details = 'Method not implemented!'
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details(details)
        raise NotImplementedError(details)
|
||||||
|
|
||||||
|
|
||||||
|
def add_SglangSchedulerServicer_to_server(servicer, server):
    """Register the servicer's four RPC handlers on a gRPC server.

    Args:
        servicer: Object providing Generate, Embed, HealthCheck and Abort.
        server: Server on which the handlers are installed.
    """
    pb2 = sglang__scheduler__pb2
    service_name = 'sglang.grpc.scheduler.SglangScheduler'

    # Generate is the only server-streaming RPC; the rest are unary-unary.
    rpc_method_handlers = {
        'Generate': grpc.unary_stream_rpc_method_handler(
            servicer.Generate,
            request_deserializer=pb2.GenerateRequest.FromString,
            response_serializer=pb2.GenerateResponse.SerializeToString,
        ),
        'Embed': grpc.unary_unary_rpc_method_handler(
            servicer.Embed,
            request_deserializer=pb2.EmbedRequest.FromString,
            response_serializer=pb2.EmbedResponse.SerializeToString,
        ),
        'HealthCheck': grpc.unary_unary_rpc_method_handler(
            servicer.HealthCheck,
            request_deserializer=pb2.HealthCheckRequest.FromString,
            response_serializer=pb2.HealthCheckResponse.SerializeToString,
        ),
        'Abort': grpc.unary_unary_rpc_method_handler(
            servicer.Abort,
            request_deserializer=pb2.AbortRequest.FromString,
            response_serializer=pb2.AbortResponse.SerializeToString,
        ),
    }

    # The same handler table is installed twice: once wrapped in a generic
    # handler and once through the registered-method API.
    generic_handler = grpc.method_handlers_generic_handler(service_name, rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers(service_name, rpc_method_handlers)
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class SglangScheduler(object):
|
||||||
|
"""Service definition for SGLang scheduler communication
|
||||||
|
This protocol bridges the Rust router and Python scheduler
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Generate(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_stream(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/sglang.grpc.scheduler.SglangScheduler/Generate',
|
||||||
|
sglang__scheduler__pb2.GenerateRequest.SerializeToString,
|
||||||
|
sglang__scheduler__pb2.GenerateResponse.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Embed(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/sglang.grpc.scheduler.SglangScheduler/Embed',
|
||||||
|
sglang__scheduler__pb2.EmbedRequest.SerializeToString,
|
||||||
|
sglang__scheduler__pb2.EmbedResponse.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def HealthCheck(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/sglang.grpc.scheduler.SglangScheduler/HealthCheck',
|
||||||
|
sglang__scheduler__pb2.HealthCheckRequest.SerializeToString,
|
||||||
|
sglang__scheduler__pb2.HealthCheckResponse.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Abort(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/sglang.grpc.scheduler.SglangScheduler/Abort',
|
||||||
|
sglang__scheduler__pb2.AbortRequest.SerializeToString,
|
||||||
|
sglang__scheduler__pb2.AbortResponse.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
@@ -2238,6 +2238,7 @@ class ServerArgs:
|
|||||||
args.pp_size = args.pipeline_parallel_size
|
args.pp_size = args.pipeline_parallel_size
|
||||||
args.dp_size = args.data_parallel_size
|
args.dp_size = args.data_parallel_size
|
||||||
args.ep_size = args.expert_parallel_size
|
args.ep_size = args.expert_parallel_size
|
||||||
|
|
||||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||||
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||||
|
|
||||||
|
|||||||
@@ -37,21 +37,6 @@ impl SglangSchedulerClient {
|
|||||||
Ok(Self { client })
|
Ok(Self { client })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initialize the connection
|
|
||||||
pub async fn initialize(
|
|
||||||
&mut self,
|
|
||||||
client_id: String,
|
|
||||||
) -> Result<proto::InitializeResponse, Box<dyn std::error::Error>> {
|
|
||||||
let request = Request::new(proto::InitializeRequest {
|
|
||||||
client_id,
|
|
||||||
client_version: "0.1.0".to_string(),
|
|
||||||
mode: proto::initialize_request::Mode::Regular as i32,
|
|
||||||
});
|
|
||||||
|
|
||||||
let response = self.client.initialize(request).await?;
|
|
||||||
Ok(response.into_inner())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Submit a generation request (returns streaming response)
|
/// Submit a generation request (returns streaming response)
|
||||||
pub async fn generate_stream(
|
pub async fn generate_stream(
|
||||||
&mut self,
|
&mut self,
|
||||||
@@ -68,7 +53,10 @@ impl SglangSchedulerClient {
|
|||||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
|
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
|
||||||
debug!("Sending health check request");
|
debug!("Sending health check request");
|
||||||
let request = Request::new(proto::HealthCheckRequest {
|
let request = Request::new(proto::HealthCheckRequest {
|
||||||
include_detailed_metrics: false,
|
tokenized: Some(proto::TokenizedInput {
|
||||||
|
original_text: "Hello".to_string(),
|
||||||
|
input_ids: vec![9906], // Mock token ID for "Hello"
|
||||||
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
let response = self.client.health_check(request).await?;
|
let response = self.client.health_check(request).await?;
|
||||||
@@ -87,21 +75,6 @@ impl SglangSchedulerClient {
|
|||||||
self.client.abort(request).await?;
|
self.client.abort(request).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Flush cache
|
|
||||||
pub async fn flush_cache(
|
|
||||||
&mut self,
|
|
||||||
flush_all: bool,
|
|
||||||
session_ids: &[String],
|
|
||||||
) -> Result<proto::FlushCacheResponse, Box<dyn std::error::Error>> {
|
|
||||||
let request = Request::new(proto::FlushCacheRequest {
|
|
||||||
flush_all,
|
|
||||||
session_ids: session_ids.to_vec(),
|
|
||||||
});
|
|
||||||
|
|
||||||
let response = self.client.flush_cache(request).await?;
|
|
||||||
Ok(response.into_inner())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -111,14 +84,13 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_proto_types_compilation() {
|
fn test_proto_types_compilation() {
|
||||||
// Test that protobuf types can be constructed
|
// Test that protobuf types can be constructed
|
||||||
let init_req = proto::InitializeRequest {
|
let health_req = proto::HealthCheckRequest {
|
||||||
client_id: "test-client".to_string(),
|
tokenized: Some(proto::TokenizedInput {
|
||||||
client_version: "0.1.0".to_string(),
|
original_text: "test".to_string(),
|
||||||
mode: 0,
|
input_ids: vec![1296],
|
||||||
|
}),
|
||||||
};
|
};
|
||||||
assert_eq!(init_req.client_id, "test-client");
|
assert!(health_req.tokenized.is_some());
|
||||||
assert_eq!(init_req.client_version, "0.1.0");
|
|
||||||
assert_eq!(init_req.mode, 0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -134,9 +106,10 @@ mod tests {
|
|||||||
|
|
||||||
let gen_req = proto::GenerateRequest {
|
let gen_req = proto::GenerateRequest {
|
||||||
request_id: "test-req-123".to_string(),
|
request_id: "test-req-123".to_string(),
|
||||||
input: Some(proto::generate_request::Input::Text(
|
tokenized: Some(proto::TokenizedInput {
|
||||||
"Hello world".to_string(),
|
original_text: "Hello world".to_string(),
|
||||||
)),
|
input_ids: vec![9906, 1917], // Mock token IDs for "Hello world"
|
||||||
|
}),
|
||||||
sampling_params: Some(sampling_params),
|
sampling_params: Some(sampling_params),
|
||||||
return_logprob: true,
|
return_logprob: true,
|
||||||
logprob_start_len: 0,
|
logprob_start_len: 0,
|
||||||
@@ -145,8 +118,8 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(gen_req.request_id, "test-req-123");
|
assert_eq!(gen_req.request_id, "test-req-123");
|
||||||
if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input {
|
if let Some(ref tokenized) = &gen_req.tokenized {
|
||||||
assert_eq!(text, "Hello world");
|
assert_eq!(tokenized.original_text, "Hello world");
|
||||||
}
|
}
|
||||||
assert!(gen_req.return_logprob);
|
assert!(gen_req.return_logprob);
|
||||||
assert_eq!(gen_req.top_logprobs_num, 5);
|
assert_eq!(gen_req.top_logprobs_num, 5);
|
||||||
@@ -160,9 +133,12 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_health_check_request() {
|
fn test_health_check_request() {
|
||||||
let health_req = proto::HealthCheckRequest {
|
let health_req = proto::HealthCheckRequest {
|
||||||
include_detailed_metrics: true,
|
tokenized: Some(proto::TokenizedInput {
|
||||||
|
original_text: "test".to_string(),
|
||||||
|
input_ids: vec![1296], // Mock token ID for "test"
|
||||||
|
}),
|
||||||
};
|
};
|
||||||
assert!(health_req.include_detailed_metrics);
|
assert!(health_req.tokenized.is_some());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -175,17 +151,6 @@ mod tests {
|
|||||||
assert_eq!(abort_req.reason, "User canceled");
|
assert_eq!(abort_req.reason, "User canceled");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_flush_cache_request() {
|
|
||||||
let flush_req = proto::FlushCacheRequest {
|
|
||||||
flush_all: true,
|
|
||||||
session_ids: vec!["session1".to_string(), "session2".to_string()],
|
|
||||||
};
|
|
||||||
assert!(flush_req.flush_all);
|
|
||||||
assert_eq!(flush_req.session_ids.len(), 2);
|
|
||||||
assert_eq!(flush_req.session_ids[0], "session1");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sampling_params_defaults() {
|
fn test_sampling_params_defaults() {
|
||||||
let params = proto::SamplingParams::default();
|
let params = proto::SamplingParams::default();
|
||||||
@@ -214,38 +179,29 @@ mod tests {
|
|||||||
assert_eq!(mm_inputs.modalities[0], "image");
|
assert_eq!(mm_inputs.modalities[0], "image");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
// TODO: SessionParams not in current proto - skip test
|
||||||
fn test_session_params() {
|
// #[test]
|
||||||
let session_params = proto::SessionParams {
|
// fn test_session_params() { ... }
|
||||||
session_id: "sess-789".to_string(),
|
|
||||||
request_id: "req-101".to_string(),
|
|
||||||
offset: 100,
|
|
||||||
replace: true,
|
|
||||||
drop_previous_output: false,
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(session_params.session_id, "sess-789");
|
|
||||||
assert_eq!(session_params.request_id, "req-101");
|
|
||||||
assert_eq!(session_params.offset, 100);
|
|
||||||
assert!(session_params.replace);
|
|
||||||
assert!(!session_params.drop_previous_output);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_embed_request() {
|
fn test_embed_request() {
|
||||||
let embed_req = proto::EmbedRequest {
|
let embed_req = proto::EmbedRequest {
|
||||||
request_id: "embed-req-202".to_string(),
|
request_id: "embed-req-202".to_string(),
|
||||||
input: Some(proto::embed_request::Input::Text(
|
tokenized: Some(proto::TokenizedInput {
|
||||||
"This is a test sentence for embedding".to_string(),
|
original_text: "This is a test sentence for embedding".to_string(),
|
||||||
)),
|
input_ids: vec![2028, 374, 264, 1296, 11914, 369, 28537], // Mock token IDs
|
||||||
|
}),
|
||||||
log_metrics: true,
|
log_metrics: true,
|
||||||
data_parallel_rank: 0,
|
data_parallel_rank: 0,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(embed_req.request_id, "embed-req-202");
|
assert_eq!(embed_req.request_id, "embed-req-202");
|
||||||
if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input {
|
if let Some(ref tokenized) = &embed_req.tokenized {
|
||||||
assert_eq!(text, "This is a test sentence for embedding");
|
assert_eq!(
|
||||||
|
tokenized.original_text,
|
||||||
|
"This is a test sentence for embedding"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
assert!(embed_req.log_metrics);
|
assert!(embed_req.log_metrics);
|
||||||
assert_eq!(embed_req.data_parallel_rank, 0);
|
assert_eq!(embed_req.data_parallel_rank, 0);
|
||||||
@@ -292,36 +248,7 @@ mod tests {
|
|||||||
assert_eq!(chunk.queue_time, 10);
|
assert_eq!(chunk.queue_time, 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
// TODO: ModelInfo not in current proto - skip test
|
||||||
fn test_model_info() {
|
// #[test]
|
||||||
let model_info = proto::ModelInfo {
|
// fn test_model_info() { ... }
|
||||||
model_name: "Meta-Llama-3-8B-Instruct".to_string(),
|
|
||||||
max_context_length: 8192,
|
|
||||||
vocab_size: 128256,
|
|
||||||
supports_tool_calling: true,
|
|
||||||
supports_vision: false,
|
|
||||||
special_tokens: vec![
|
|
||||||
"<|begin_of_text|>".to_string(),
|
|
||||||
"<|end_of_text|>".to_string(),
|
|
||||||
],
|
|
||||||
model_type: "llama".to_string(),
|
|
||||||
num_layers: 32,
|
|
||||||
hidden_size: 4096,
|
|
||||||
num_attention_heads: 32,
|
|
||||||
num_key_value_heads: 8,
|
|
||||||
tokenizer_type: "llama".to_string(),
|
|
||||||
eos_token_ids: vec![128001, 128009],
|
|
||||||
pad_token_id: 128001,
|
|
||||||
bos_token_id: 128000,
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct");
|
|
||||||
assert_eq!(model_info.max_context_length, 8192);
|
|
||||||
assert_eq!(model_info.vocab_size, 128256);
|
|
||||||
assert!(model_info.supports_tool_calling);
|
|
||||||
assert!(!model_info.supports_vision);
|
|
||||||
assert_eq!(model_info.special_tokens.len(), 2);
|
|
||||||
assert_eq!(model_info.num_layers, 32);
|
|
||||||
assert_eq!(model_info.eos_token_ids, vec![128001, 128009]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,9 +8,6 @@ import "google/protobuf/struct.proto";
|
|||||||
// Service definition for SGLang scheduler communication
|
// Service definition for SGLang scheduler communication
|
||||||
// This protocol bridges the Rust router and Python scheduler
|
// This protocol bridges the Rust router and Python scheduler
|
||||||
service SglangScheduler {
|
service SglangScheduler {
|
||||||
// Initialize connection and get model info
|
|
||||||
rpc Initialize(InitializeRequest) returns (InitializeResponse);
|
|
||||||
|
|
||||||
// Submit a generation request (supports streaming)
|
// Submit a generation request (supports streaming)
|
||||||
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
|
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
|
||||||
|
|
||||||
@@ -23,8 +20,6 @@ service SglangScheduler {
|
|||||||
// Abort a running request
|
// Abort a running request
|
||||||
rpc Abort(AbortRequest) returns (AbortResponse);
|
rpc Abort(AbortRequest) returns (AbortResponse);
|
||||||
|
|
||||||
// Flush KV cache
|
|
||||||
rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// =====================
|
// =====================
|
||||||
@@ -75,14 +70,6 @@ message SamplingParams {
|
|||||||
google.protobuf.Struct custom_params = 25;
|
google.protobuf.Struct custom_params = 25;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Session parameters for continual prompting
|
|
||||||
message SessionParams {
|
|
||||||
string session_id = 1;
|
|
||||||
string request_id = 2;
|
|
||||||
int32 offset = 3;
|
|
||||||
bool replace = 4;
|
|
||||||
bool drop_previous_output = 5;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disaggregated serving parameters
|
// Disaggregated serving parameters
|
||||||
message DisaggregatedParams {
|
message DisaggregatedParams {
|
||||||
@@ -91,87 +78,6 @@ message DisaggregatedParams {
|
|||||||
int32 bootstrap_room = 3;
|
int32 bootstrap_room = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// =====================
|
|
||||||
// Initialize
|
|
||||||
// =====================
|
|
||||||
|
|
||||||
message InitializeRequest {
|
|
||||||
string client_id = 1;
|
|
||||||
string client_version = 2;
|
|
||||||
|
|
||||||
// Operating mode
|
|
||||||
enum Mode {
|
|
||||||
REGULAR = 0; // Normal mode with local scheduler
|
|
||||||
PREFILL = 1; // Prefill-only mode for disaggregated serving
|
|
||||||
DECODE = 2; // Decode-only mode for disaggregated serving
|
|
||||||
}
|
|
||||||
Mode mode = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
message InitializeResponse {
|
|
||||||
bool success = 1;
|
|
||||||
string scheduler_version = 2;
|
|
||||||
|
|
||||||
// Model information
|
|
||||||
ModelInfo model_info = 3;
|
|
||||||
|
|
||||||
// Server capabilities
|
|
||||||
ServerCapabilities capabilities = 4;
|
|
||||||
|
|
||||||
// Error message if success is false
|
|
||||||
string error_message = 5;
|
|
||||||
}
|
|
||||||
|
|
||||||
message ModelInfo {
|
|
||||||
string model_name = 1;
|
|
||||||
int32 max_context_length = 2;
|
|
||||||
int32 vocab_size = 3;
|
|
||||||
bool supports_tool_calling = 4;
|
|
||||||
bool supports_vision = 5;
|
|
||||||
repeated string special_tokens = 6;
|
|
||||||
|
|
||||||
// Additional model metadata
|
|
||||||
string model_type = 7;
|
|
||||||
int32 num_layers = 8;
|
|
||||||
int32 hidden_size = 9;
|
|
||||||
int32 num_attention_heads = 10;
|
|
||||||
int32 num_key_value_heads = 11;
|
|
||||||
|
|
||||||
// Tokenizer info
|
|
||||||
string tokenizer_type = 12;
|
|
||||||
repeated int32 eos_token_ids = 13;
|
|
||||||
int32 pad_token_id = 14;
|
|
||||||
int32 bos_token_id = 15;
|
|
||||||
}
|
|
||||||
|
|
||||||
message ServerCapabilities {
|
|
||||||
bool continuous_batching = 1;
|
|
||||||
bool disaggregated_serving = 2;
|
|
||||||
bool speculative_decoding = 3;
|
|
||||||
int32 max_batch_size = 4;
|
|
||||||
int32 max_num_batched_tokens = 5;
|
|
||||||
int32 max_prefill_tokens = 6;
|
|
||||||
string attention_backend = 7; // "flashinfer", "triton", "torch"
|
|
||||||
|
|
||||||
// Additional capabilities
|
|
||||||
bool supports_lora = 8;
|
|
||||||
bool supports_grammar = 9;
|
|
||||||
bool supports_multimodal = 10;
|
|
||||||
repeated string supported_modalities = 11; // ["image", "video", "audio"]
|
|
||||||
bool supports_custom_logit_processor = 12;
|
|
||||||
bool supports_session = 13;
|
|
||||||
|
|
||||||
// Hardware info
|
|
||||||
int32 num_gpus = 14;
|
|
||||||
string gpu_type = 15;
|
|
||||||
int64 total_gpu_memory = 16;
|
|
||||||
|
|
||||||
// Parallelism info
|
|
||||||
int32 tensor_parallel_size = 17;
|
|
||||||
int32 pipeline_parallel_size = 18;
|
|
||||||
int32 data_parallel_size = 19;
|
|
||||||
}
|
|
||||||
|
|
||||||
// =====================
|
// =====================
|
||||||
// Generate Request
|
// Generate Request
|
||||||
// =====================
|
// =====================
|
||||||
@@ -179,49 +85,43 @@ message ServerCapabilities {
|
|||||||
message GenerateRequest {
|
message GenerateRequest {
|
||||||
string request_id = 1;
|
string request_id = 1;
|
||||||
|
|
||||||
// Input can be either text or tokenized
|
// Input must be tokenized (no raw text)
|
||||||
oneof input {
|
TokenizedInput tokenized = 2;
|
||||||
string text = 2;
|
|
||||||
TokenizedInput tokenized = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multimodal inputs
|
// Multimodal inputs
|
||||||
MultimodalInputs mm_inputs = 4;
|
MultimodalInputs mm_inputs = 3;
|
||||||
|
|
||||||
// Generation parameters
|
// Generation parameters
|
||||||
SamplingParams sampling_params = 5;
|
SamplingParams sampling_params = 4;
|
||||||
|
|
||||||
// Return options
|
// Return options
|
||||||
bool return_logprob = 6;
|
bool return_logprob = 5;
|
||||||
int32 logprob_start_len = 7;
|
int32 logprob_start_len = 6;
|
||||||
int32 top_logprobs_num = 8;
|
int32 top_logprobs_num = 7;
|
||||||
repeated int32 token_ids_logprob = 9;
|
repeated int32 token_ids_logprob = 8;
|
||||||
bool return_hidden_states = 10;
|
bool return_hidden_states = 9;
|
||||||
|
|
||||||
// Session management
|
|
||||||
SessionParams session_params = 11;
|
|
||||||
|
|
||||||
// For disaggregated serving
|
// For disaggregated serving
|
||||||
DisaggregatedParams disaggregated_params = 12;
|
DisaggregatedParams disaggregated_params = 10;
|
||||||
|
|
||||||
// Custom logit processor (serialized)
|
// Custom logit processor (serialized)
|
||||||
string custom_logit_processor = 13;
|
string custom_logit_processor = 11;
|
||||||
|
|
||||||
// Request metadata
|
// Request metadata
|
||||||
google.protobuf.Timestamp timestamp = 14;
|
google.protobuf.Timestamp timestamp = 12;
|
||||||
bool log_metrics = 15;
|
bool log_metrics = 13;
|
||||||
|
|
||||||
// Input embeddings (alternative to text/tokens)
|
// Input embeddings (alternative to text/tokens)
|
||||||
repeated float input_embeds = 16;
|
repeated float input_embeds = 14;
|
||||||
|
|
||||||
// LoRA adapter ID (if pre-loaded)
|
// LoRA adapter ID (if pre-loaded)
|
||||||
string lora_id = 17;
|
string lora_id = 15;
|
||||||
|
|
||||||
// Data parallel routing
|
// Data parallel routing
|
||||||
int32 data_parallel_rank = 18;
|
int32 data_parallel_rank = 16;
|
||||||
|
|
||||||
// For load balancing
|
// For load balancing
|
||||||
int32 dp_balance_id = 19;
|
int32 dp_balance_id = 17;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TokenizedInput {
|
message TokenizedInput {
|
||||||
@@ -303,19 +203,6 @@ message GenerateComplete {
|
|||||||
}
|
}
|
||||||
FinishReason finish_reason = 3;
|
FinishReason finish_reason = 3;
|
||||||
|
|
||||||
// Final counts
|
|
||||||
int32 prompt_tokens = 4;
|
|
||||||
int32 completion_tokens = 5;
|
|
||||||
int32 cached_tokens = 6;
|
|
||||||
|
|
||||||
// Performance metrics
|
|
||||||
float total_generation_time = 7;
|
|
||||||
float time_to_first_token = 8;
|
|
||||||
float tokens_per_second = 9;
|
|
||||||
|
|
||||||
// Spec decode metrics
|
|
||||||
int32 spec_verify_count = 10;
|
|
||||||
|
|
||||||
// All logprobs if requested
|
// All logprobs if requested
|
||||||
repeated LogProbs all_logprobs = 11;
|
repeated LogProbs all_logprobs = 11;
|
||||||
|
|
||||||
@@ -359,10 +246,8 @@ message HiddenStates {
|
|||||||
message EmbedRequest {
|
message EmbedRequest {
|
||||||
string request_id = 1;
|
string request_id = 1;
|
||||||
|
|
||||||
oneof input {
|
// Input must be tokenized (no raw text)
|
||||||
string text = 2;
|
TokenizedInput tokenized = 2;
|
||||||
TokenizedInput tokenized = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Multimodal inputs
|
// Multimodal inputs
|
||||||
MultimodalInputs mm_inputs = 4;
|
MultimodalInputs mm_inputs = 4;
|
||||||
@@ -422,39 +307,13 @@ message EmbedError {
|
|||||||
// =====================
|
// =====================
|
||||||
|
|
||||||
message HealthCheckRequest {
|
message HealthCheckRequest {
|
||||||
bool include_detailed_metrics = 1;
|
// Input for health test generation (must be tokenized)
|
||||||
|
TokenizedInput tokenized = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message HealthCheckResponse {
|
message HealthCheckResponse {
|
||||||
bool healthy = 1;
|
bool healthy = 1;
|
||||||
|
string message = 2;
|
||||||
// Current load metrics
|
|
||||||
int32 num_requests_running = 2;
|
|
||||||
int32 num_requests_waiting = 3;
|
|
||||||
float gpu_cache_usage = 4;
|
|
||||||
float gpu_memory_usage = 5;
|
|
||||||
|
|
||||||
// KV cache metrics
|
|
||||||
int32 kv_cache_total_blocks = 6;
|
|
||||||
int32 kv_cache_used_blocks = 7;
|
|
||||||
float kv_cache_hit_rate = 8;
|
|
||||||
|
|
||||||
// Additional metrics
|
|
||||||
int32 num_grammar_queue_requests = 9;
|
|
||||||
float generation_throughput = 10; // tokens/sec
|
|
||||||
float average_queue_time = 11; // seconds
|
|
||||||
float average_generation_time = 12; // seconds
|
|
||||||
|
|
||||||
// System metrics
|
|
||||||
float cpu_usage = 13;
|
|
||||||
int64 memory_usage = 14;
|
|
||||||
|
|
||||||
// Disaggregation metrics
|
|
||||||
int32 num_prefill_requests = 15;
|
|
||||||
int32 num_decode_requests = 16;
|
|
||||||
|
|
||||||
// Detailed metrics (optional)
|
|
||||||
google.protobuf.Struct detailed_metrics = 17;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message AbortRequest {
|
message AbortRequest {
|
||||||
@@ -467,17 +326,6 @@ message AbortResponse {
|
|||||||
string message = 2;
|
string message = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message FlushCacheRequest {
|
|
||||||
bool flush_all = 1;
|
|
||||||
repeated string session_ids = 2; // Flush specific sessions
|
|
||||||
}
|
|
||||||
|
|
||||||
message FlushCacheResponse {
|
|
||||||
bool success = 1;
|
|
||||||
int32 num_entries_flushed = 2;
|
|
||||||
int64 memory_freed = 3; // bytes
|
|
||||||
string message = 4;
|
|
||||||
}
|
|
||||||
|
|
||||||
// =====================
|
// =====================
|
||||||
// Additional Operations (Future)
|
// Additional Operations (Future)
|
||||||
|
|||||||
Reference in New Issue
Block a user