// [router] Add spec for sglang scheduler (#9322)
// New file: sgl-router/src/proto/sglang_scheduler.proto (541 lines)
syntax = "proto3";

package sglang.grpc.scheduler;

import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";

// Service definition for SGLang scheduler communication.
// This protocol bridges the Rust router and Python scheduler.
service SGLangScheduler {
  // Initialize connection and get model info.
  rpc Initialize(InitializeRequest) returns (InitializeResponse);

  // Submit a generation request; results stream back incrementally.
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);

  // Submit an embedding request.
  rpc Embed(EmbedRequest) returns (EmbedResponse);

  // Health check and metrics.
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);

  // Abort a running request.
  rpc AbortRequest(AbortRequest) returns (AbortResponse);

  // Flush KV cache.
  rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
}
|
||||
|
||||
// =====================
// Common Types
// =====================

// Sampling parameters matching SGLang's SamplingParams.
message SamplingParams {
  // Core token-sampling controls.
  float temperature = 1;
  float top_p = 2;
  int32 top_k = 3;
  float min_p = 4;
  float frequency_penalty = 5;
  float presence_penalty = 6;
  float repetition_penalty = 7;

  // Output length and stop conditions.
  int32 max_new_tokens = 8;
  repeated string stop = 9;
  repeated int32 stop_token_ids = 10;
  bool skip_special_tokens = 11;
  bool spaces_between_special_tokens = 12;

  // Structured generation: at most one constraint kind may be set.
  oneof constraint {
    string regex = 13;
    string json_schema = 14;
    string ebnf_grammar = 15;
  }

  // LoRA adapter
  string lora_path = 16;

  // Speculative decoding
  int32 n = 17;  // Number of samples

  // Token healing
  bool token_healing = 18;

  // Additional parameters
  int32 min_new_tokens = 19;
  bool ignore_eos = 20;
  bool no_stop_trim = 21;
  int32 stream_interval = 22;
  // Per-token bias; keys presumably stringified token IDs — TODO confirm
  // against the Python scheduler's SamplingParams.
  map<string, float> logit_bias = 23;
  string structural_tag = 24;

  // Custom parameters for extensibility.
  google.protobuf.Struct custom_params = 25;
}
|
||||
|
||||
// Session parameters for continual prompting.
message SessionParams {
  string session_id = 1;
  string request_id = 2;
  int32 offset = 3;
  bool replace = 4;
  bool drop_previous_output = 5;
}
|
||||
|
||||
// Disaggregated serving parameters (prefill/decode bootstrap rendezvous).
message DisaggregatedParams {
  string bootstrap_host = 1;
  int32 bootstrap_port = 2;
  int32 bootstrap_room = 3;
}
|
||||
|
||||
// =====================
// Initialize
// =====================

message InitializeRequest {
  string client_id = 1;
  string client_version = 2;

  // Operating mode.
  // NOTE(review): values lack a MODE_ prefix and a *_UNSPECIFIED zero value
  // (proto3 convention); REGULAR doubles as the default. Renaming/renumbering
  // now would break generated code, so this is flagged rather than changed.
  enum Mode {
    REGULAR = 0;  // Normal mode with local scheduler
    PREFILL = 1;  // Prefill-only mode for disaggregated serving
    DECODE = 2;   // Decode-only mode for disaggregated serving
  }
  Mode mode = 3;
}
|
||||
|
||||
message InitializeResponse {
  bool success = 1;
  string scheduler_version = 2;

  // Model information
  ModelInfo model_info = 3;

  // Server capabilities
  ServerCapabilities capabilities = 4;

  // Error message if success is false
  string error_message = 5;
}
|
||||
|
||||
// Static description of the model served by the scheduler.
message ModelInfo {
  string model_name = 1;
  int32 max_context_length = 2;
  int32 vocab_size = 3;
  bool supports_tool_calling = 4;
  bool supports_vision = 5;
  repeated string special_tokens = 6;

  // Additional model metadata
  string model_type = 7;
  int32 num_layers = 8;
  int32 hidden_size = 9;
  int32 num_attention_heads = 10;
  int32 num_key_value_heads = 11;

  // Tokenizer info
  string tokenizer_type = 12;
  repeated int32 eos_token_ids = 13;
  int32 pad_token_id = 14;
  int32 bos_token_id = 15;
}
|
||||
|
||||
// Feature and capacity flags advertised by the scheduler at initialization.
message ServerCapabilities {
  bool continuous_batching = 1;
  bool disaggregated_serving = 2;
  bool speculative_decoding = 3;
  int32 max_batch_size = 4;
  int32 max_num_batched_tokens = 5;
  int32 max_prefill_tokens = 6;
  string attention_backend = 7;  // "flashinfer", "triton", "torch"

  // Additional capabilities
  bool supports_lora = 8;
  bool supports_grammar = 9;
  bool supports_multimodal = 10;
  repeated string supported_modalities = 11;  // ["image", "video", "audio"]
  bool supports_custom_logit_processor = 12;
  bool supports_session = 13;

  // Hardware info
  int32 num_gpus = 14;
  string gpu_type = 15;
  int64 total_gpu_memory = 16;  // presumably bytes — TODO confirm units

  // Parallelism info
  int32 tensor_parallel_size = 17;
  int32 pipeline_parallel_size = 18;
  int32 data_parallel_size = 19;
}
|
||||
|
||||
// =====================
// Generate Request
// =====================

message GenerateRequest {
  string request_id = 1;

  // Input can be either raw text or pre-tokenized IDs.
  oneof input {
    string text = 2;
    TokenizedInput tokenized = 3;
  }

  // Multimodal inputs
  MultimodalInputs mm_inputs = 4;

  // Generation parameters
  SamplingParams sampling_params = 5;

  // Return options
  bool return_logprob = 6;
  int32 logprob_start_len = 7;
  int32 top_logprobs_num = 8;
  repeated int32 token_ids_logprob = 9;
  bool return_hidden_states = 10;

  // Session management
  SessionParams session_params = 11;

  // For disaggregated serving
  DisaggregatedParams disaggregated_params = 12;

  // Custom logit processor (serialized)
  string custom_logit_processor = 13;

  // Request metadata
  google.protobuf.Timestamp timestamp = 14;
  bool log_metrics = 15;

  // Input embeddings (alternative to text/tokens).
  // NOTE(review): flat float list; the shape contract (tokens x hidden_size?)
  // is not expressed here — confirm with the scheduler implementation.
  repeated float input_embeds = 16;

  // LoRA adapter ID (if pre-loaded)
  string lora_id = 17;

  // Data parallel routing
  int32 data_parallel_rank = 18;

  // For load balancing
  int32 dp_balance_id = 19;
}
|
||||
|
||||
// Pre-tokenized input, carrying both the IDs and the source text.
message TokenizedInput {
  string original_text = 1;  // For reference
  repeated int32 input_ids = 2;
}
|
||||
|
||||
message MultimodalInputs {
  // Simplified multimodal handling - actual data processed by tokenizer
  repeated string image_urls = 1;
  repeated string video_urls = 2;
  repeated string audio_urls = 3;

  // Pre-processed multimodal features (if available)
  google.protobuf.Struct processed_features = 4;

  // Raw data for direct processing
  repeated bytes image_data = 5;
  repeated bytes video_data = 6;
  repeated bytes audio_data = 7;

  // Modality metadata
  repeated string modalities = 8;
}
|
||||
|
||||
// =====================
// Generate Response
// =====================

// One streamed message of a generation; exactly one of chunk/complete/error
// is set per message.
message GenerateResponse {
  string request_id = 1;

  // Response type
  oneof response {
    GenerateStreamChunk chunk = 2;
    GenerateComplete complete = 3;
    GenerateError error = 4;
  }
}
|
||||
|
||||
// Incremental generation output for one token.
message GenerateStreamChunk {
  // Generated token
  int32 token_id = 1;
  string text = 2;

  // Cumulative counts
  int32 prompt_tokens = 3;
  int32 completion_tokens = 4;
  int32 cached_tokens = 5;

  // Logprobs (if requested)
  LogProbs logprobs = 6;

  // Hidden states (if requested)
  repeated float hidden_states = 7;

  // Metadata
  // NOTE(review): units are not specified for either time field, and
  // queue_time is int32 while generation_time is float — confirm units
  // (seconds vs. ms) with the scheduler before relying on them.
  float generation_time = 8;  // Time to generate this token
  int32 queue_time = 9;       // Time spent in queue
}
|
||||
|
||||
// Final message of a generation stream.
message GenerateComplete {
  // Final output
  repeated int32 output_ids = 1;
  string output_text = 2;

  // Finish reason.
  // NOTE(review): STOP occupies the zero value, so an unset field is
  // indistinguishable from "stopped normally"; proto3 convention would be a
  // *_UNSPECIFIED zero value. Renumbering now would be wire-breaking.
  enum FinishReason {
    // The model generated a stop sequence.
    STOP = 0;
    // The model reached the maximum generation length.
    LENGTH = 1;
    // The model generated an end-of-sequence (EOS) token.
    EOS_TOKEN = 2;
    // The model generated a user-provided stop string.
    STOP_STR = 3;
    // The request was aborted by the user or system.
    ABORT = 4;
  }
  FinishReason finish_reason = 3;

  // Final counts
  int32 prompt_tokens = 4;
  int32 completion_tokens = 5;
  int32 cached_tokens = 6;

  // Performance metrics
  float total_generation_time = 7;
  float time_to_first_token = 8;
  float tokens_per_second = 9;

  // Spec decode metrics
  int32 spec_verify_count = 10;

  // All logprobs if requested
  repeated LogProbs all_logprobs = 11;

  // All hidden states if requested
  repeated HiddenStates all_hidden_states = 12;
}
|
||||
|
||||
// Terminal error for a generation stream.
message GenerateError {
  string message = 1;
  // NOTE(review): typed as string rather than int32 despite the name —
  // confirm the intended representation with the Python scheduler.
  string http_status_code = 2;
  string details = 3;
}
|
||||
|
||||
// Per-position log-probabilities. The repeated fields are parallel arrays
// indexed by token position.
message LogProbs {
  repeated float token_logprobs = 1;
  repeated int32 token_ids = 2;

  // Top logprobs at each position
  repeated TopLogProbs top_logprobs = 3;

  // Decoded text for tokens
  repeated string token_texts = 4;
}
|
||||
|
||||
// Top-k alternatives at a single position (parallel arrays).
message TopLogProbs {
  repeated float values = 1;
  repeated int32 token_ids = 2;
  repeated string token_texts = 3;
}
|
||||
|
||||
// Hidden-state vector for one (layer, position) pair.
message HiddenStates {
  repeated float values = 1;
  int32 layer = 2;
  int32 position = 3;
}
|
||||
|
||||
// =====================
// Embedding Request
// =====================

message EmbedRequest {
  string request_id = 1;

  oneof input {
    string text = 2;
    TokenizedInput tokenized = 3;
  }

  // Multimodal inputs
  MultimodalInputs mm_inputs = 4;

  // Dummy sampling params for compatibility;
  // EmbedRequest doesn't use sampling_params.
  SamplingParams sampling_params = 5;

  bool log_metrics = 6;

  // Token type IDs for models that require them
  repeated int32 token_type_ids = 7;

  // Data parallel routing
  int32 data_parallel_rank = 8;

  // For cross-encoder requests
  bool is_cross_encoder = 9;
  repeated string texts = 10;  // For cross-encoder batch
}
|
||||
|
||||
// Embedding result: exactly one of complete/error is set.
message EmbedResponse {
  string request_id = 1;

  oneof response {
    EmbedComplete complete = 2;
    EmbedError error = 3;
  }
}
|
||||
|
||||
// Successful embedding result.
message EmbedComplete {
  repeated float embedding = 1;
  int32 prompt_tokens = 2;
  int32 cached_tokens = 3;

  // Additional metadata
  int32 embedding_dim = 4;
  float generation_time = 5;  // units unspecified — TODO confirm (seconds?)

  // For batch embeddings
  repeated Embedding batch_embeddings = 6;
}

// One embedding in a batch, tagged with its input index.
message Embedding {
  repeated float values = 1;
  int32 index = 2;
}
|
||||
|
||||
// Terminal error for an embedding request.
message EmbedError {
  string message = 1;
  string code = 2;
  string details = 3;
}
|
||||
|
||||
// =====================
// Management Operations
// =====================

message HealthCheckRequest {
  bool include_detailed_metrics = 1;
}
|
||||
|
||||
message HealthCheckResponse {
  bool healthy = 1;

  // Current load metrics
  int32 num_requests_running = 2;
  int32 num_requests_waiting = 3;
  float gpu_cache_usage = 4;
  float gpu_memory_usage = 5;

  // KV cache metrics
  int32 kv_cache_total_blocks = 6;
  int32 kv_cache_used_blocks = 7;
  float kv_cache_hit_rate = 8;

  // Additional metrics
  int32 num_grammar_queue_requests = 9;
  float generation_throughput = 10;    // tokens/sec
  float average_queue_time = 11;       // seconds
  float average_generation_time = 12;  // seconds

  // System metrics
  float cpu_usage = 13;
  int64 memory_usage = 14;  // presumably bytes — TODO confirm units

  // Disaggregation metrics
  int32 num_prefill_requests = 15;
  int32 num_decode_requests = 16;

  // Detailed metrics (populated when
  // HealthCheckRequest.include_detailed_metrics is set)
  google.protobuf.Struct detailed_metrics = 17;
}
|
||||
|
||||
message AbortRequest {
  string request_id = 1;
  string reason = 2;
}

message AbortResponse {
  bool success = 1;
  string message = 2;
}
|
||||
|
||||
message FlushCacheRequest {
  bool flush_all = 1;
  repeated string session_ids = 2;  // Flush specific sessions
}

message FlushCacheResponse {
  bool success = 1;
  int32 num_entries_flushed = 2;
  int64 memory_freed = 3;  // bytes
  string message = 4;
}
|
||||
|
||||
// =====================
// Additional Operations (Future)
// =====================
// NOTE(review): none of the messages below are referenced by an RPC in
// SGLangScheduler yet; they are reserved for future service methods.

// Load LoRA adapter
message LoadLoRARequest {
  string adapter_id = 1;
  string adapter_path = 2;
  int32 rank = 3;
}

message LoadLoRAResponse {
  bool success = 1;
  string adapter_id = 2;
  string message = 3;
}

// Unload LoRA adapter
message UnloadLoRARequest {
  string adapter_id = 1;
}

message UnloadLoRAResponse {
  bool success = 1;
  string message = 2;
}
|
||||
|
||||
// Update weights; exactly one source may be set.
message UpdateWeightsRequest {
  oneof source {
    string disk_path = 1;
    bytes tensor_data = 2;
    string remote_url = 3;
  }
  string weight_name = 4;
}

message UpdateWeightsResponse {
  bool success = 1;
  string message = 2;
}
|
||||
|
||||
// Get internal state for debugging
message GetInternalStateRequest {
  repeated string state_keys = 1;
}

message GetInternalStateResponse {
  google.protobuf.Struct state = 1;
}

// Set internal state for testing
message SetInternalStateRequest {
  google.protobuf.Struct state = 1;
}

message SetInternalStateResponse {
  bool success = 1;
  string message = 2;
}
|
||||
// (end of file)