// Protocol Buffer schema for SGLang scheduler communication.
syntax = "proto3";

package sglang.grpc.scheduler;

import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";
// Service definition for SGLang scheduler communication.
// This protocol bridges the Rust router and Python scheduler.
service SglangScheduler {
  // Submit a generation request. The response is a server stream of
  // GenerateResponse messages (incremental chunks, then a final
  // complete/error message — see the GenerateResponse oneof).
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);

  // Submit an embedding request (unary).
  rpc Embed(EmbedRequest) returns (EmbedResponse);

  // Health check; the request carries a tokenized input for a test
  // generation (see HealthCheckRequest).
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);

  // Abort a running request identified by its request_id.
  rpc Abort(AbortRequest) returns (AbortResponse);
}
|
|
|
|
// =====================
// Common Types
// =====================

// Sampling parameters matching SGLang's SamplingParams.
// NOTE(review): proto3 scalars have implicit presence — an unset field
// reads as 0/0.0/false/"" and is indistinguishable from an explicit
// zero, so the receiver must supply its own defaults (e.g. an unset
// temperature arrives as 0.0). Confirm the scheduler handles this.
message SamplingParams {
  // Core sampling knobs.
  float temperature = 1;
  float top_p = 2;
  int32 top_k = 3;
  float min_p = 4;
  float frequency_penalty = 5;
  float presence_penalty = 6;
  float repetition_penalty = 7;

  // Length and stopping controls.
  int32 max_new_tokens = 8;
  repeated string stop = 9;            // Stop strings.
  repeated int32 stop_token_ids = 10;  // Stop token ids.
  bool skip_special_tokens = 11;
  bool spaces_between_special_tokens = 12;

  // Structured generation: at most one constraint may be set.
  oneof constraint {
    string regex = 13;
    string json_schema = 14;
    string ebnf_grammar = 15;
  }

  // LoRA adapter path.
  string lora_path = 16;

  // Number of samples to generate.
  // NOTE(review): the original section header here said "Speculative
  // decoding", which conflicts with "Number of samples" — confirm intent.
  int32 n = 17; // Number of samples

  // Token healing.
  bool token_healing = 18;

  // Additional parameters.
  int32 min_new_tokens = 19;
  bool ignore_eos = 20;
  bool no_stop_trim = 21;
  int32 stream_interval = 22;          // presumably tokens per stream chunk — verify against scheduler
  map<string, float> logit_bias = 23;  // presumably keyed by token id rendered as a string — verify
  string structural_tag = 24;

  // Custom parameters for extensibility (untyped JSON-like payload;
  // keep off hot paths — Struct defeats schema-aware handling).
  google.protobuf.Struct custom_params = 25;
}
|
|
|
|
|
|
// Disaggregated serving parameters (prefill/decode split bootstrap info).
message DisaggregatedParams {
  // Host used to rendezvous the two halves of the request.
  string bootstrap_host = 1;
  int32 bootstrap_port = 2;
  // Room identifier for the bootstrap rendezvous.
  // NOTE(review): int32 — confirm room ids cannot exceed 2^31-1.
  int32 bootstrap_room = 3;
}
|
|
|
|
// =====================
// Generate Request
// =====================

message GenerateRequest {
  // Unique id; used to correlate streamed responses and Abort calls.
  string request_id = 1;

  // Input must be tokenized (no raw text)
  TokenizedInput tokenized = 2;

  // Multimodal inputs
  MultimodalInputs mm_inputs = 3;

  // Generation parameters
  SamplingParams sampling_params = 4;

  // Return options
  bool return_logprob = 5;
  int32 logprob_start_len = 6;           // presumably the position logprob reporting starts at — verify
  int32 top_logprobs_num = 7;            // number of top alternatives per position
  repeated int32 token_ids_logprob = 8;  // specific token ids to report logprobs for
  bool return_hidden_states = 9;

  // For disaggregated serving
  DisaggregatedParams disaggregated_params = 10;

  // Custom logit processor (serialized)
  string custom_logit_processor = 11;

  // Request metadata
  google.protobuf.Timestamp timestamp = 12;
  bool log_metrics = 13;

  // Input embeddings (alternative to text/tokens)
  // NOTE(review): flattened into one float list; the layout (e.g.
  // seq_len * hidden_dim) is not specified here — confirm with scheduler.
  repeated float input_embeds = 14;

  // LoRA adapter ID (if pre-loaded)
  string lora_id = 15;

  // Data parallel routing
  int32 data_parallel_rank = 16;

  // For load balancing
  int32 dp_balance_id = 17;
}
|
|
|
|
// A pre-tokenized prompt. Tokenization happens before this protocol;
// input_ids is the authoritative input, original_text is informational.
message TokenizedInput {
  string original_text = 1; // For reference only; not re-tokenized here.
  repeated int32 input_ids = 2;
}
|
|
|
|
message MultimodalInputs {
  // Simplified multimodal handling - actual data processed by tokenizer
  repeated string image_urls = 1;
  repeated string video_urls = 2;
  repeated string audio_urls = 3;

  // Pre-processed multimodal features (if available)
  google.protobuf.Struct processed_features = 4;

  // Raw data for direct processing
  // NOTE(review): large blobs in a single message risk hitting protobuf
  // size limits (~64 MB default); consider chunking if payloads grow.
  repeated bytes image_data = 5;
  repeated bytes video_data = 6;
  repeated bytes audio_data = 7;

  // Modality metadata
  // NOTE(review): presumably one entry per multimodal item — confirm
  // how these align with the url/data lists above.
  repeated string modalities = 8;
}
|
|
|
|
// =====================
// Generate Response
// =====================

// One message of the Generate server stream. A stream consists of zero
// or more `chunk` messages, terminated by exactly one `complete` or
// `error` — TODO confirm this framing with the scheduler implementation.
message GenerateResponse {
  // Echoes GenerateRequest.request_id.
  string request_id = 1;

  // Response type — exactly one of the following is set.
  oneof response {
    GenerateStreamChunk chunk = 2;
    GenerateComplete complete = 3;
    GenerateError error = 4;
  }
}
|
|
|
|
// An incremental piece of a streamed generation.
message GenerateStreamChunk {
  // Generated token and its decoded text.
  int32 token_id = 1;
  string text = 2;

  // Cumulative counts for the request so far.
  int32 prompt_tokens = 3;
  int32 completion_tokens = 4;
  int32 cached_tokens = 5;

  // Logprobs (if requested via GenerateRequest.return_logprob).
  LogProbs logprobs = 6;

  // Hidden states (if requested). Flattened vector; layout unspecified
  // here — verify against producer.
  repeated float hidden_states = 7;

  // Metadata.
  // NOTE(review): units are not specified for either timing field —
  // presumably seconds (float) and milliseconds (int); confirm.
  float generation_time = 8; // Time to generate this token
  int32 queue_time = 9;      // Time spent in queue
}
|
|
|
|
// Terminal message of a successful generation stream.
message GenerateComplete {
  // Final output: full token-id sequence and its decoded text.
  repeated int32 output_ids = 1;
  string output_text = 2;

  // Field numbers 4-10 are not used by this message (numbering jumps
  // from 3 to 11). Reserve them so they cannot be accidentally reused
  // with different semantics; if the gap was intentionally left for
  // planned fields, remove this statement when assigning them.
  reserved 4 to 10;

  // Why generation stopped.
  // NOTE(review): proto3 best practice is a <TYPE>_UNSPECIFIED zero
  // value and type-prefixed value names; both would be breaking changes
  // now (STOP = 0 is the implicit default on the wire), so the existing
  // values are intentionally kept as-is.
  enum FinishReason {
    // The model generated a stop sequence.
    STOP = 0;
    // The model reached the maximum generation length.
    LENGTH = 1;
    // The model generated an end-of-sequence (EOS) token.
    EOS_TOKEN = 2;
    // The model generated a user-provided stop string.
    STOP_STR = 3;
    // The request was aborted by the user or system.
    ABORT = 4;
  }
  FinishReason finish_reason = 3;

  // All logprobs if requested (one LogProbs per segment/position group —
  // verify granularity against producer).
  repeated LogProbs all_logprobs = 11;

  // All hidden states if requested.
  repeated HiddenStates all_hidden_states = 12;
}
|
|
|
|
// Terminal error message of a generation stream.
message GenerateError {
  // Human-readable error description.
  string message = 1;
  // HTTP status code for the router to surface.
  // NOTE(review): typed as string rather than int32; changing the type
  // now would be wire-breaking, so it stays — confirm expected format
  // (e.g. "500"). Also note EmbedError calls the analogous field `code`.
  string http_status_code = 2;
  // Additional detail for debugging/logging.
  string details = 3;
}
|
|
|
|
// Log-probabilities for a run of tokens.
// The repeated fields are parallel arrays indexed by position:
// token_logprobs[i], token_ids[i], top_logprobs[i] and token_texts[i]
// all describe position i, so producers must keep them equal length.
message LogProbs {
  repeated float token_logprobs = 1;  // Logprob of each chosen token.
  repeated int32 token_ids = 2;       // Chosen token ids.

  // Top logprobs at each position
  repeated TopLogProbs top_logprobs = 3;

  // Decoded text for tokens
  repeated string token_texts = 4;
}
|
|
|
|
// Top-k alternatives at a single position. Parallel arrays: values[i],
// token_ids[i] and token_texts[i] describe the same candidate token.
message TopLogProbs {
  repeated float values = 1;        // Logprob of each candidate.
  repeated int32 token_ids = 2;     // Candidate token ids.
  repeated string token_texts = 3;  // Decoded candidate text.
}
|
|
|
|
// A hidden-state vector for one (layer, position) pair.
message HiddenStates {
  repeated float values = 1;  // Flattened vector (presumably hidden_dim long — verify).
  int32 layer = 2;            // Layer index.
  int32 position = 3;         // Token position.
}
|
|
|
|
// =====================
// Embedding Request
// =====================

message EmbedRequest {
  // Unique id; echoed in EmbedResponse.request_id.
  string request_id = 1;

  // Input must be tokenized (no raw text); but see `texts` below for
  // the cross-encoder path.
  TokenizedInput tokenized = 2;

  // Field number 3 is unused (numbering jumps from 2 to 4). Reserve it
  // so it cannot be silently reused with new semantics; remove this if
  // the gap was deliberately kept for a planned field.
  reserved 3;

  // Multimodal inputs.
  MultimodalInputs mm_inputs = 4;

  // Dummy sampling params for compatibility;
  // EmbedRequest doesn't use sampling_params.
  SamplingParams sampling_params = 5;

  bool log_metrics = 6;

  // Token type IDs for models that require them.
  repeated int32 token_type_ids = 7;

  // Data parallel routing.
  int32 data_parallel_rank = 8;

  // For cross-encoder requests.
  bool is_cross_encoder = 9;
  repeated string texts = 10; // For cross-encoder batch
}
|
|
|
|
// Unary response to Embed. Exactly one of complete/error is set.
message EmbedResponse {
  // Echoes EmbedRequest.request_id.
  string request_id = 1;

  oneof response {
    EmbedComplete complete = 2;
    EmbedError error = 3;
  }
}
|
|
|
|
// Successful embedding result.
message EmbedComplete {
  // The embedding vector for a single-input request.
  repeated float embedding = 1;
  int32 prompt_tokens = 2;
  int32 cached_tokens = 3;

  // Additional metadata
  int32 embedding_dim = 4;
  float generation_time = 5;  // Units unspecified — presumably seconds; verify.

  // For batch embeddings (presumably used instead of `embedding`
  // when the request carried multiple inputs — confirm).
  repeated Embedding batch_embeddings = 6;
}
|
|
|
|
// One embedding of a batch, tagged with its input index.
message Embedding {
  repeated float values = 1;  // The embedding vector.
  int32 index = 2;            // Position of the corresponding input in the batch.
}
|
|
|
|
// Error result for an Embed call.
// NOTE(review): field naming differs from GenerateError, which uses
// `http_status_code` for field 2 — confirm whether `code` here carries
// the same HTTP-status semantics.
message EmbedError {
  string message = 1;  // Human-readable error description.
  string code = 2;     // Error code (format unspecified — verify).
  string details = 3;  // Additional detail for debugging/logging.
}
|
|
|
|
// =====================
// Management Operations
// =====================

message HealthCheckRequest {
  // Input for health test generation (must be tokenized)
  TokenizedInput tokenized = 1;
}
|
|
|
|
message HealthCheckResponse {
  bool healthy = 1;    // True when the scheduler can serve requests.
  string message = 2;  // Optional human-readable status detail.
}
|
|
|
|
message AbortRequest {
  // Id of the in-flight request to abort (GenerateRequest.request_id).
  string request_id = 1;
  // Optional human-readable reason, for logging.
  string reason = 2;
}
|
|
|
|
message AbortResponse {
  bool success = 1;    // True if the abort was accepted/applied.
  string message = 2;  // Optional detail (e.g. why the abort failed).
}
|
|
|
|
|
|
// =====================
// Additional Operations (Future)
// =====================
// NOTE(review): the messages below have no corresponding rpc in
// SglangScheduler yet (at least within this file) — they appear to be
// staged for future service methods.

// Load LoRA adapter
message LoadLoRARequest {
  string adapter_id = 1;    // Caller-chosen id to reference the adapter later.
  string adapter_path = 2;  // Filesystem path to the adapter weights.
  int32 rank = 3;           // LoRA rank — TODO confirm whether 0 means "infer".
}
|
|
|
|
message LoadLoRAResponse {
  bool success = 1;
  string adapter_id = 2;  // Echoes the loaded adapter's id.
  string message = 3;     // Optional detail on success/failure.
}
|
|
|
|
// Unload LoRA adapter
message UnloadLoRARequest {
  // Id previously passed to LoadLoRARequest.adapter_id.
  string adapter_id = 1;
}
|
|
|
|
message UnloadLoRAResponse {
  bool success = 1;
  string message = 2;  // Optional detail on success/failure.
}
|
|
|
|
// Update weights
message UpdateWeightsRequest {
  // Where to load the new weights from — exactly one must be set.
  oneof source {
    string disk_path = 1;   // Local filesystem path.
    bytes tensor_data = 2;  // Inline serialized tensor payload — format unspecified; verify.
    string remote_url = 3;  // Remote location to fetch from.
  }
  // Name of the weight to update — TODO confirm whether empty means
  // "replace all weights".
  string weight_name = 4;
}
|
|
|
|
message UpdateWeightsResponse {
  bool success = 1;
  string message = 2;  // Optional detail on success/failure.
}
|
|
|
|
// Get internal state for debugging
message GetInternalStateRequest {
  // Keys to fetch — TODO confirm whether an empty list means "all keys".
  repeated string state_keys = 1;
}
|
|
|
|
message GetInternalStateResponse {
  // Requested state as an untyped JSON-like structure.
  google.protobuf.Struct state = 1;
}
|
|
|
|
// Set internal state for testing
message SetInternalStateRequest {
  // State to apply, as an untyped JSON-like structure.
  google.protobuf.Struct state = 1;
}
|
|
|
|
message SetInternalStateResponse {
  bool success = 1;
  string message = 2;  // Optional detail on success/failure.
}
|