Implement Standalone gRPC Server for SGLang Python Scheduler (#10283)
This commit is contained in:
@@ -37,21 +37,6 @@ impl SglangSchedulerClient {
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
/// Initialize the connection
|
||||
pub async fn initialize(
|
||||
&mut self,
|
||||
client_id: String,
|
||||
) -> Result<proto::InitializeResponse, Box<dyn std::error::Error>> {
|
||||
let request = Request::new(proto::InitializeRequest {
|
||||
client_id,
|
||||
client_version: "0.1.0".to_string(),
|
||||
mode: proto::initialize_request::Mode::Regular as i32,
|
||||
});
|
||||
|
||||
let response = self.client.initialize(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
/// Submit a generation request (returns streaming response)
|
||||
pub async fn generate_stream(
|
||||
&mut self,
|
||||
@@ -68,7 +53,10 @@ impl SglangSchedulerClient {
|
||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
|
||||
debug!("Sending health check request");
|
||||
let request = Request::new(proto::HealthCheckRequest {
|
||||
include_detailed_metrics: false,
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: "Hello".to_string(),
|
||||
input_ids: vec![9906], // Mock token ID for "Hello"
|
||||
}),
|
||||
});
|
||||
|
||||
let response = self.client.health_check(request).await?;
|
||||
@@ -87,21 +75,6 @@ impl SglangSchedulerClient {
|
||||
self.client.abort(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush cache
|
||||
pub async fn flush_cache(
|
||||
&mut self,
|
||||
flush_all: bool,
|
||||
session_ids: &[String],
|
||||
) -> Result<proto::FlushCacheResponse, Box<dyn std::error::Error>> {
|
||||
let request = Request::new(proto::FlushCacheRequest {
|
||||
flush_all,
|
||||
session_ids: session_ids.to_vec(),
|
||||
});
|
||||
|
||||
let response = self.client.flush_cache(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -111,14 +84,13 @@ mod tests {
|
||||
#[test]
|
||||
fn test_proto_types_compilation() {
|
||||
// Test that protobuf types can be constructed
|
||||
let init_req = proto::InitializeRequest {
|
||||
client_id: "test-client".to_string(),
|
||||
client_version: "0.1.0".to_string(),
|
||||
mode: 0,
|
||||
let health_req = proto::HealthCheckRequest {
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: "test".to_string(),
|
||||
input_ids: vec![1296],
|
||||
}),
|
||||
};
|
||||
assert_eq!(init_req.client_id, "test-client");
|
||||
assert_eq!(init_req.client_version, "0.1.0");
|
||||
assert_eq!(init_req.mode, 0);
|
||||
assert!(health_req.tokenized.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -134,9 +106,10 @@ mod tests {
|
||||
|
||||
let gen_req = proto::GenerateRequest {
|
||||
request_id: "test-req-123".to_string(),
|
||||
input: Some(proto::generate_request::Input::Text(
|
||||
"Hello world".to_string(),
|
||||
)),
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: "Hello world".to_string(),
|
||||
input_ids: vec![9906, 1917], // Mock token IDs for "Hello world"
|
||||
}),
|
||||
sampling_params: Some(sampling_params),
|
||||
return_logprob: true,
|
||||
logprob_start_len: 0,
|
||||
@@ -145,8 +118,8 @@ mod tests {
|
||||
};
|
||||
|
||||
assert_eq!(gen_req.request_id, "test-req-123");
|
||||
if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input {
|
||||
assert_eq!(text, "Hello world");
|
||||
if let Some(ref tokenized) = &gen_req.tokenized {
|
||||
assert_eq!(tokenized.original_text, "Hello world");
|
||||
}
|
||||
assert!(gen_req.return_logprob);
|
||||
assert_eq!(gen_req.top_logprobs_num, 5);
|
||||
@@ -160,9 +133,12 @@ mod tests {
|
||||
#[test]
|
||||
fn test_health_check_request() {
|
||||
let health_req = proto::HealthCheckRequest {
|
||||
include_detailed_metrics: true,
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: "test".to_string(),
|
||||
input_ids: vec![1296], // Mock token ID for "test"
|
||||
}),
|
||||
};
|
||||
assert!(health_req.include_detailed_metrics);
|
||||
assert!(health_req.tokenized.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -175,17 +151,6 @@ mod tests {
|
||||
assert_eq!(abort_req.reason, "User canceled");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flush_cache_request() {
|
||||
let flush_req = proto::FlushCacheRequest {
|
||||
flush_all: true,
|
||||
session_ids: vec!["session1".to_string(), "session2".to_string()],
|
||||
};
|
||||
assert!(flush_req.flush_all);
|
||||
assert_eq!(flush_req.session_ids.len(), 2);
|
||||
assert_eq!(flush_req.session_ids[0], "session1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sampling_params_defaults() {
|
||||
let params = proto::SamplingParams::default();
|
||||
@@ -214,38 +179,29 @@ mod tests {
|
||||
assert_eq!(mm_inputs.modalities[0], "image");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_params() {
|
||||
let session_params = proto::SessionParams {
|
||||
session_id: "sess-789".to_string(),
|
||||
request_id: "req-101".to_string(),
|
||||
offset: 100,
|
||||
replace: true,
|
||||
drop_previous_output: false,
|
||||
};
|
||||
|
||||
assert_eq!(session_params.session_id, "sess-789");
|
||||
assert_eq!(session_params.request_id, "req-101");
|
||||
assert_eq!(session_params.offset, 100);
|
||||
assert!(session_params.replace);
|
||||
assert!(!session_params.drop_previous_output);
|
||||
}
|
||||
// TODO: SessionParams not in current proto - skip test
|
||||
// #[test]
|
||||
// fn test_session_params() { ... }
|
||||
|
||||
#[test]
|
||||
fn test_embed_request() {
|
||||
let embed_req = proto::EmbedRequest {
|
||||
request_id: "embed-req-202".to_string(),
|
||||
input: Some(proto::embed_request::Input::Text(
|
||||
"This is a test sentence for embedding".to_string(),
|
||||
)),
|
||||
tokenized: Some(proto::TokenizedInput {
|
||||
original_text: "This is a test sentence for embedding".to_string(),
|
||||
input_ids: vec![2028, 374, 264, 1296, 11914, 369, 28537], // Mock token IDs
|
||||
}),
|
||||
log_metrics: true,
|
||||
data_parallel_rank: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(embed_req.request_id, "embed-req-202");
|
||||
if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input {
|
||||
assert_eq!(text, "This is a test sentence for embedding");
|
||||
if let Some(ref tokenized) = &embed_req.tokenized {
|
||||
assert_eq!(
|
||||
tokenized.original_text,
|
||||
"This is a test sentence for embedding"
|
||||
);
|
||||
}
|
||||
assert!(embed_req.log_metrics);
|
||||
assert_eq!(embed_req.data_parallel_rank, 0);
|
||||
@@ -292,36 +248,7 @@ mod tests {
|
||||
assert_eq!(chunk.queue_time, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_info() {
|
||||
let model_info = proto::ModelInfo {
|
||||
model_name: "Meta-Llama-3-8B-Instruct".to_string(),
|
||||
max_context_length: 8192,
|
||||
vocab_size: 128256,
|
||||
supports_tool_calling: true,
|
||||
supports_vision: false,
|
||||
special_tokens: vec![
|
||||
"<|begin_of_text|>".to_string(),
|
||||
"<|end_of_text|>".to_string(),
|
||||
],
|
||||
model_type: "llama".to_string(),
|
||||
num_layers: 32,
|
||||
hidden_size: 4096,
|
||||
num_attention_heads: 32,
|
||||
num_key_value_heads: 8,
|
||||
tokenizer_type: "llama".to_string(),
|
||||
eos_token_ids: vec![128001, 128009],
|
||||
pad_token_id: 128001,
|
||||
bos_token_id: 128000,
|
||||
};
|
||||
|
||||
assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct");
|
||||
assert_eq!(model_info.max_context_length, 8192);
|
||||
assert_eq!(model_info.vocab_size, 128256);
|
||||
assert!(model_info.supports_tool_calling);
|
||||
assert!(!model_info.supports_vision);
|
||||
assert_eq!(model_info.special_tokens.len(), 2);
|
||||
assert_eq!(model_info.num_layers, 32);
|
||||
assert_eq!(model_info.eos_token_ids, vec![128001, 128009]);
|
||||
}
|
||||
// TODO: ModelInfo not in current proto - skip test
|
||||
// #[test]
|
||||
// fn test_model_info() { ... }
|
||||
}
|
||||
|
||||
@@ -8,9 +8,6 @@ import "google/protobuf/struct.proto";
|
||||
// Service definition for SGLang scheduler communication
|
||||
// This protocol bridges the Rust router and Python scheduler
|
||||
service SglangScheduler {
|
||||
// Initialize connection and get model info
|
||||
rpc Initialize(InitializeRequest) returns (InitializeResponse);
|
||||
|
||||
// Submit a generation request (supports streaming)
|
||||
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
|
||||
|
||||
@@ -23,8 +20,6 @@ service SglangScheduler {
|
||||
// Abort a running request
|
||||
rpc Abort(AbortRequest) returns (AbortResponse);
|
||||
|
||||
// Flush KV cache
|
||||
rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
|
||||
}
|
||||
|
||||
// =====================
|
||||
@@ -75,14 +70,6 @@ message SamplingParams {
|
||||
google.protobuf.Struct custom_params = 25;
|
||||
}
|
||||
|
||||
// Session parameters for continual prompting
|
||||
message SessionParams {
|
||||
string session_id = 1;
|
||||
string request_id = 2;
|
||||
int32 offset = 3;
|
||||
bool replace = 4;
|
||||
bool drop_previous_output = 5;
|
||||
}
|
||||
|
||||
// Disaggregated serving parameters
|
||||
message DisaggregatedParams {
|
||||
@@ -91,87 +78,6 @@ message DisaggregatedParams {
|
||||
int32 bootstrap_room = 3;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Initialize
|
||||
// =====================
|
||||
|
||||
message InitializeRequest {
|
||||
string client_id = 1;
|
||||
string client_version = 2;
|
||||
|
||||
// Operating mode
|
||||
enum Mode {
|
||||
REGULAR = 0; // Normal mode with local scheduler
|
||||
PREFILL = 1; // Prefill-only mode for disaggregated serving
|
||||
DECODE = 2; // Decode-only mode for disaggregated serving
|
||||
}
|
||||
Mode mode = 3;
|
||||
}
|
||||
|
||||
message InitializeResponse {
|
||||
bool success = 1;
|
||||
string scheduler_version = 2;
|
||||
|
||||
// Model information
|
||||
ModelInfo model_info = 3;
|
||||
|
||||
// Server capabilities
|
||||
ServerCapabilities capabilities = 4;
|
||||
|
||||
// Error message if success is false
|
||||
string error_message = 5;
|
||||
}
|
||||
|
||||
message ModelInfo {
|
||||
string model_name = 1;
|
||||
int32 max_context_length = 2;
|
||||
int32 vocab_size = 3;
|
||||
bool supports_tool_calling = 4;
|
||||
bool supports_vision = 5;
|
||||
repeated string special_tokens = 6;
|
||||
|
||||
// Additional model metadata
|
||||
string model_type = 7;
|
||||
int32 num_layers = 8;
|
||||
int32 hidden_size = 9;
|
||||
int32 num_attention_heads = 10;
|
||||
int32 num_key_value_heads = 11;
|
||||
|
||||
// Tokenizer info
|
||||
string tokenizer_type = 12;
|
||||
repeated int32 eos_token_ids = 13;
|
||||
int32 pad_token_id = 14;
|
||||
int32 bos_token_id = 15;
|
||||
}
|
||||
|
||||
message ServerCapabilities {
|
||||
bool continuous_batching = 1;
|
||||
bool disaggregated_serving = 2;
|
||||
bool speculative_decoding = 3;
|
||||
int32 max_batch_size = 4;
|
||||
int32 max_num_batched_tokens = 5;
|
||||
int32 max_prefill_tokens = 6;
|
||||
string attention_backend = 7; // "flashinfer", "triton", "torch"
|
||||
|
||||
// Additional capabilities
|
||||
bool supports_lora = 8;
|
||||
bool supports_grammar = 9;
|
||||
bool supports_multimodal = 10;
|
||||
repeated string supported_modalities = 11; // ["image", "video", "audio"]
|
||||
bool supports_custom_logit_processor = 12;
|
||||
bool supports_session = 13;
|
||||
|
||||
// Hardware info
|
||||
int32 num_gpus = 14;
|
||||
string gpu_type = 15;
|
||||
int64 total_gpu_memory = 16;
|
||||
|
||||
// Parallelism info
|
||||
int32 tensor_parallel_size = 17;
|
||||
int32 pipeline_parallel_size = 18;
|
||||
int32 data_parallel_size = 19;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Generate Request
|
||||
// =====================
|
||||
@@ -179,49 +85,43 @@ message ServerCapabilities {
|
||||
message GenerateRequest {
|
||||
string request_id = 1;
|
||||
|
||||
// Input can be either text or tokenized
|
||||
oneof input {
|
||||
string text = 2;
|
||||
TokenizedInput tokenized = 3;
|
||||
}
|
||||
// Input must be tokenized (no raw text)
|
||||
TokenizedInput tokenized = 2;
|
||||
|
||||
// Multimodal inputs
|
||||
MultimodalInputs mm_inputs = 4;
|
||||
MultimodalInputs mm_inputs = 3;
|
||||
|
||||
// Generation parameters
|
||||
SamplingParams sampling_params = 5;
|
||||
SamplingParams sampling_params = 4;
|
||||
|
||||
// Return options
|
||||
bool return_logprob = 6;
|
||||
int32 logprob_start_len = 7;
|
||||
int32 top_logprobs_num = 8;
|
||||
repeated int32 token_ids_logprob = 9;
|
||||
bool return_hidden_states = 10;
|
||||
|
||||
// Session management
|
||||
SessionParams session_params = 11;
|
||||
bool return_logprob = 5;
|
||||
int32 logprob_start_len = 6;
|
||||
int32 top_logprobs_num = 7;
|
||||
repeated int32 token_ids_logprob = 8;
|
||||
bool return_hidden_states = 9;
|
||||
|
||||
// For disaggregated serving
|
||||
DisaggregatedParams disaggregated_params = 12;
|
||||
DisaggregatedParams disaggregated_params = 10;
|
||||
|
||||
// Custom logit processor (serialized)
|
||||
string custom_logit_processor = 13;
|
||||
string custom_logit_processor = 11;
|
||||
|
||||
// Request metadata
|
||||
google.protobuf.Timestamp timestamp = 14;
|
||||
bool log_metrics = 15;
|
||||
google.protobuf.Timestamp timestamp = 12;
|
||||
bool log_metrics = 13;
|
||||
|
||||
// Input embeddings (alternative to text/tokens)
|
||||
repeated float input_embeds = 16;
|
||||
repeated float input_embeds = 14;
|
||||
|
||||
// LoRA adapter ID (if pre-loaded)
|
||||
string lora_id = 17;
|
||||
string lora_id = 15;
|
||||
|
||||
// Data parallel routing
|
||||
int32 data_parallel_rank = 18;
|
||||
int32 data_parallel_rank = 16;
|
||||
|
||||
// For load balancing
|
||||
int32 dp_balance_id = 19;
|
||||
int32 dp_balance_id = 17;
|
||||
}
|
||||
|
||||
message TokenizedInput {
|
||||
@@ -303,19 +203,6 @@ message GenerateComplete {
|
||||
}
|
||||
FinishReason finish_reason = 3;
|
||||
|
||||
// Final counts
|
||||
int32 prompt_tokens = 4;
|
||||
int32 completion_tokens = 5;
|
||||
int32 cached_tokens = 6;
|
||||
|
||||
// Performance metrics
|
||||
float total_generation_time = 7;
|
||||
float time_to_first_token = 8;
|
||||
float tokens_per_second = 9;
|
||||
|
||||
// Spec decode metrics
|
||||
int32 spec_verify_count = 10;
|
||||
|
||||
// All logprobs if requested
|
||||
repeated LogProbs all_logprobs = 11;
|
||||
|
||||
@@ -359,10 +246,8 @@ message HiddenStates {
|
||||
message EmbedRequest {
|
||||
string request_id = 1;
|
||||
|
||||
oneof input {
|
||||
string text = 2;
|
||||
TokenizedInput tokenized = 3;
|
||||
}
|
||||
// Input must be tokenized (no raw text)
|
||||
TokenizedInput tokenized = 2;
|
||||
|
||||
// Multimodal inputs
|
||||
MultimodalInputs mm_inputs = 4;
|
||||
@@ -422,39 +307,13 @@ message EmbedError {
|
||||
// =====================
|
||||
|
||||
message HealthCheckRequest {
|
||||
bool include_detailed_metrics = 1;
|
||||
// Input for health test generation (must be tokenized)
|
||||
TokenizedInput tokenized = 1;
|
||||
}
|
||||
|
||||
message HealthCheckResponse {
|
||||
bool healthy = 1;
|
||||
|
||||
// Current load metrics
|
||||
int32 num_requests_running = 2;
|
||||
int32 num_requests_waiting = 3;
|
||||
float gpu_cache_usage = 4;
|
||||
float gpu_memory_usage = 5;
|
||||
|
||||
// KV cache metrics
|
||||
int32 kv_cache_total_blocks = 6;
|
||||
int32 kv_cache_used_blocks = 7;
|
||||
float kv_cache_hit_rate = 8;
|
||||
|
||||
// Additional metrics
|
||||
int32 num_grammar_queue_requests = 9;
|
||||
float generation_throughput = 10; // tokens/sec
|
||||
float average_queue_time = 11; // seconds
|
||||
float average_generation_time = 12; // seconds
|
||||
|
||||
// System metrics
|
||||
float cpu_usage = 13;
|
||||
int64 memory_usage = 14;
|
||||
|
||||
// Disaggregation metrics
|
||||
int32 num_prefill_requests = 15;
|
||||
int32 num_decode_requests = 16;
|
||||
|
||||
// Detailed metrics (optional)
|
||||
google.protobuf.Struct detailed_metrics = 17;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message AbortRequest {
|
||||
@@ -467,17 +326,6 @@ message AbortResponse {
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
message FlushCacheRequest {
|
||||
bool flush_all = 1;
|
||||
repeated string session_ids = 2; // Flush specific sessions
|
||||
}
|
||||
|
||||
message FlushCacheResponse {
|
||||
bool success = 1;
|
||||
int32 num_entries_flushed = 2;
|
||||
int64 memory_freed = 3; // bytes
|
||||
string message = 4;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Additional Operations (Future)
|
||||
|
||||
Reference in New Issue
Block a user