syntax = "proto3";

package sglang.grpc.scheduler;

import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";

// Service definition for SGLang scheduler communication.
// This protocol bridges the Rust router and Python scheduler.
service SGLangScheduler {
  // Initialize connection and get model info.
  rpc Initialize(InitializeRequest) returns (InitializeResponse);

  // Submit a generation request (server-streaming response).
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);

  // Submit an embedding request.
  rpc Embed(EmbedRequest) returns (EmbedResponse);

  // Health check and load metrics.
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);

  // Abort a running request.
  rpc AbortRequest(AbortRequest) returns (AbortResponse);

  // Flush KV cache (all entries or specific sessions).
  rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
}

// =====================
// Common Types
// =====================

// Sampling parameters matching SGLang's SamplingParams.
message SamplingParams {
  float temperature = 1;
  float top_p = 2;
  int32 top_k = 3;
  float min_p = 4;
  float frequency_penalty = 5;
  float presence_penalty = 6;
  float repetition_penalty = 7;

  int32 max_new_tokens = 8;
  repeated string stop = 9;
  repeated int32 stop_token_ids = 10;
  bool skip_special_tokens = 11;
  bool spaces_between_special_tokens = 12;

  // Structured generation: at most one constraint kind may be set.
  oneof constraint {
    string regex = 13;
    string json_schema = 14;
    string ebnf_grammar = 15;
  }

  // LoRA adapter path.
  string lora_path = 16;

  // Number of samples to generate.
  int32 n = 17;

  // Token healing.
  bool token_healing = 18;

  // Additional parameters.
  int32 min_new_tokens = 19;
  bool ignore_eos = 20;
  bool no_stop_trim = 21;
  int32 stream_interval = 22;
  // Per-token logit bias, keyed by token id rendered as a string.
  // NOTE(review): the original declaration lacked type parameters
  // (`map logit_bias`), which is invalid proto3; map<string, float>
  // matches SGLang's logit_bias contract — TODO confirm against the
  // Python scheduler before shipping.
  map<string, float> logit_bias = 23;
  string structural_tag = 24;

  // Custom parameters for extensibility (untyped; keep off hot paths).
  google.protobuf.Struct custom_params = 25;
}

// Session parameters for continual prompting.
message SessionParams {
  string session_id = 1;
  string request_id = 2;
  int32 offset = 3;
  bool replace = 4;
  bool drop_previous_output = 5;
}

// Disaggregated serving parameters (prefill/decode split).
message DisaggregatedParams {
  string bootstrap_host = 1;
  int32 bootstrap_port = 2;
  int32 bootstrap_room = 3;
}

// =====================
// Initialize
// =====================

message InitializeRequest {
  string client_id = 1;
  string client_version = 2;

  // Operating mode.
  // NOTE(review): values are unprefixed and the zero value carries
  // business meaning (REGULAR = 0); proto style prefers
  // MODE_UNSPECIFIED = 0 with MODE_-prefixed values. Renaming now
  // would break generated code — flagging for the next major rev.
  enum Mode {
    REGULAR = 0; // Normal mode with local scheduler
    PREFILL = 1; // Prefill-only mode for disaggregated serving
    DECODE = 2;  // Decode-only mode for disaggregated serving
  }
  Mode mode = 3;
}

message InitializeResponse {
  bool success = 1;
  string scheduler_version = 2;

  // Model information.
  ModelInfo model_info = 3;

  // Server capabilities.
  ServerCapabilities capabilities = 4;

  // Error message if success is false.
  string error_message = 5;
}

message ModelInfo {
  string model_name = 1;
  int32 max_context_length = 2;
  int32 vocab_size = 3;
  bool supports_tool_calling = 4;
  bool supports_vision = 5;
  repeated string special_tokens = 6;

  // Additional model metadata.
  string model_type = 7;
  int32 num_layers = 8;
  int32 hidden_size = 9;
  int32 num_attention_heads = 10;
  int32 num_key_value_heads = 11;

  // Tokenizer info.
  string tokenizer_type = 12;
  repeated int32 eos_token_ids = 13;
  int32 pad_token_id = 14;
  int32 bos_token_id = 15;
}

message ServerCapabilities {
  bool continuous_batching = 1;
  bool disaggregated_serving = 2;
  bool speculative_decoding = 3;
  int32 max_batch_size = 4;
  int32 max_num_batched_tokens = 5;
  int32 max_prefill_tokens = 6;
  string attention_backend = 7; // "flashinfer", "triton", "torch"

  // Additional capabilities.
  bool supports_lora = 8;
  bool supports_grammar = 9;
  bool supports_multimodal = 10;
  repeated string supported_modalities = 11; // ["image", "video", "audio"]
  bool supports_custom_logit_processor = 12;
  bool supports_session = 13;

  // Hardware info.
  int32 num_gpus = 14;
  string gpu_type = 15;
  int64 total_gpu_memory = 16; // presumably bytes — TODO confirm

  // Parallelism info.
  int32 tensor_parallel_size = 17;
  int32 pipeline_parallel_size = 18;
  int32 data_parallel_size = 19;
}

// =====================
// Generate Request
// =====================

message GenerateRequest {
  string request_id = 1;

  // Input can be either raw text or pre-tokenized ids.
  oneof input {
    string text = 2;
    TokenizedInput tokenized = 3;
  }

  // Multimodal inputs.
  MultimodalInputs mm_inputs = 4;

  // Generation parameters.
  SamplingParams sampling_params = 5;

  // Return options.
  bool return_logprob = 6;
  int32 logprob_start_len = 7;
  int32 top_logprobs_num = 8;
  repeated int32 token_ids_logprob = 9;
  bool return_hidden_states = 10;

  // Session management.
  SessionParams session_params = 11;

  // For disaggregated serving.
  DisaggregatedParams disaggregated_params = 12;

  // Custom logit processor (serialized).
  string custom_logit_processor = 13;

  // Request metadata.
  google.protobuf.Timestamp timestamp = 14;
  bool log_metrics = 15;

  // Input embeddings (alternative to text/tokens).
  repeated float input_embeds = 16;

  // LoRA adapter ID (if pre-loaded).
  string lora_id = 17;

  // Data parallel routing.
  int32 data_parallel_rank = 18;

  // For load balancing.
  int32 dp_balance_id = 19;
}

message TokenizedInput {
  string original_text = 1; // For reference
  repeated int32 input_ids = 2;
}

message MultimodalInputs {
  // Simplified multimodal handling - actual data processed by tokenizer.
  repeated string image_urls = 1;
  repeated string video_urls = 2;
  repeated string audio_urls = 3;

  // Pre-processed multimodal features (if available).
  google.protobuf.Struct processed_features = 4;

  // Raw data for direct processing.
  repeated bytes image_data = 5;
  repeated bytes video_data = 6;
  repeated bytes audio_data = 7;

  // Modality metadata.
  repeated string modalities = 8;
}

// =====================
// Generate Response
// =====================

message GenerateResponse {
  string request_id = 1;

  // Exactly one of: an incremental chunk, a final completion, or an error.
  oneof response {
    GenerateStreamChunk chunk = 2;
    GenerateComplete complete = 3;
    GenerateError error = 4;
  }
}

message GenerateStreamChunk {
  // Generated token.
  int32 token_id = 1;
  string text = 2;

  // Cumulative counts.
  int32 prompt_tokens = 3;
  int32 completion_tokens = 4;
  int32 cached_tokens = 5;

  // Logprobs (if requested).
  LogProbs logprobs = 6;

  // Hidden states (if requested).
  repeated float hidden_states = 7;

  // Metadata. NOTE(review): units are not stated by the producer;
  // presumably seconds for generation_time — confirm, and prefer
  // unit-suffixed names (e.g. *_ms) for new fields.
  float generation_time = 8; // Time to generate this token
  int32 queue_time = 9;      // Time spent in queue
}

message GenerateComplete {
  // Final output.
  repeated int32 output_ids = 1;
  string output_text = 2;

  // Finish reason.
  // NOTE(review): zero value carries business meaning (STOP = 0) and
  // values are unprefixed; proto style prefers
  // FINISH_REASON_UNSPECIFIED = 0. Renaming now would break generated
  // code — flagging for the next major rev.
  enum FinishReason {
    // The model generated a stop sequence.
    STOP = 0;
    // The model reached the maximum generation length.
    LENGTH = 1;
    // The model generated an end-of-sequence (EOS) token.
    EOS_TOKEN = 2;
    // The model generated a user-provided stop string.
    STOP_STR = 3;
    // The request was aborted by the user or system.
    ABORT = 4;
  }
  FinishReason finish_reason = 3;

  // Final counts.
  int32 prompt_tokens = 4;
  int32 completion_tokens = 5;
  int32 cached_tokens = 6;

  // Performance metrics (presumably seconds — TODO confirm units).
  float total_generation_time = 7;
  float time_to_first_token = 8;
  float tokens_per_second = 9;

  // Spec decode metrics.
  int32 spec_verify_count = 10;

  // All logprobs if requested.
  repeated LogProbs all_logprobs = 11;

  // All hidden states if requested.
  repeated HiddenStates all_hidden_states = 12;
}

message GenerateError {
  string message = 1;
  // NOTE(review): carried as a string even though HTTP status codes are
  // numeric; keeping the type as-is for wire compatibility.
  string http_status_code = 2;
  string details = 3;
}

message LogProbs {
  repeated float token_logprobs = 1;
  repeated int32 token_ids = 2;

  // Top logprobs at each position.
  repeated TopLogProbs top_logprobs = 3;

  // Decoded text for tokens.
  repeated string token_texts = 4;
}

message TopLogProbs {
  repeated float values = 1;
  repeated int32 token_ids = 2;
  repeated string token_texts = 3;
}

message HiddenStates {
  repeated float values = 1;
  int32 layer = 2;
  int32 position = 3;
}

// =====================
// Embedding Request
// =====================

message EmbedRequest {
  string request_id = 1;

  oneof input {
    string text = 2;
    TokenizedInput tokenized = 3;
  }

  // Multimodal inputs.
  MultimodalInputs mm_inputs = 4;

  // Dummy sampling params for compatibility.
  // EmbedRequest doesn't use sampling_params.
  SamplingParams sampling_params = 5;

  bool log_metrics = 6;

  // Token type IDs for models that require them.
  repeated int32 token_type_ids = 7;

  // Data parallel routing.
  int32 data_parallel_rank = 8;

  // For cross-encoder requests.
  bool is_cross_encoder = 9;
  repeated string texts = 10; // For cross-encoder batch
}

message EmbedResponse {
  string request_id = 1;

  oneof response {
    EmbedComplete complete = 2;
    EmbedError error = 3;
  }
}

message EmbedComplete {
  repeated float embedding = 1;
  int32 prompt_tokens = 2;
  int32 cached_tokens = 3;

  // Additional metadata.
  int32 embedding_dim = 4;
  float generation_time = 5; // presumably seconds — TODO confirm

  // For batch embeddings.
  repeated Embedding batch_embeddings = 6;
}

message Embedding {
  repeated float values = 1;
  int32 index = 2;
}

message EmbedError {
  string message = 1;
  string code = 2;
  string details = 3;
}

// =====================
// Management Operations
// =====================

message HealthCheckRequest {
  bool include_detailed_metrics = 1;
}

message HealthCheckResponse {
  bool healthy = 1;

  // Current load metrics.
  int32 num_requests_running = 2;
  int32 num_requests_waiting = 3;
  float gpu_cache_usage = 4;
  float gpu_memory_usage = 5;

  // KV cache metrics.
  int32 kv_cache_total_blocks = 6;
  int32 kv_cache_used_blocks = 7;
  float kv_cache_hit_rate = 8;

  // Additional metrics.
  int32 num_grammar_queue_requests = 9;
  float generation_throughput = 10;   // tokens/sec
  float average_queue_time = 11;      // seconds
  float average_generation_time = 12; // seconds

  // System metrics.
  float cpu_usage = 13;
  int64 memory_usage = 14; // presumably bytes — TODO confirm

  // Disaggregation metrics.
  int32 num_prefill_requests = 15;
  int32 num_decode_requests = 16;

  // Detailed metrics (optional).
  google.protobuf.Struct detailed_metrics = 17;
}

message AbortRequest {
  string request_id = 1;
  string reason = 2;
}

message AbortResponse {
  bool success = 1;
  string message = 2;
}

message FlushCacheRequest {
  bool flush_all = 1;
  repeated string session_ids = 2; // Flush specific sessions
}

message FlushCacheResponse {
  bool success = 1;
  int32 num_entries_flushed = 2;
  int64 memory_freed = 3; // bytes
  string message = 4;
}

// =====================
// Additional Operations (Future)
// =====================

// Load LoRA adapter.
message LoadLoRARequest {
  string adapter_id = 1;
  string adapter_path = 2;
  int32 rank = 3;
}

message LoadLoRAResponse {
  bool success = 1;
  string adapter_id = 2;
  string message = 3;
}

// Unload LoRA adapter.
message UnloadLoRARequest {
  string adapter_id = 1;
}

message UnloadLoRAResponse {
  bool success = 1;
  string message = 2;
}

// Update weights.
message UpdateWeightsRequest {
  oneof source {
    string disk_path = 1;
    // NOTE(review): protobuf messages cap at 2 GB and many runtimes
    // default-cap at 64 MB; full weight tensors should be chunked or
    // streamed rather than sent as a single bytes field.
    bytes tensor_data = 2;
    string remote_url = 3;
  }
  string weight_name = 4;
}

message UpdateWeightsResponse {
  bool success = 1;
  string message = 2;
}

// Get internal state for debugging.
message GetInternalStateRequest {
  repeated string state_keys = 1;
}

message GetInternalStateResponse {
  google.protobuf.Struct state = 1;
}

// Set internal state for testing.
message SetInternalStateRequest {
  google.protobuf.Struct state = 1;
}

message SetInternalStateResponse {
  bool success = 1;
  string message = 2;
}