[router] Implement gRPC SGLangSchedulerClient (#9364)
This commit is contained in:
2
.github/workflows/release-pypi-router.yml
vendored
2
.github/workflows/release-pypi-router.yml
vendored
@@ -47,7 +47,7 @@ jobs:
|
||||
env:
|
||||
CIBW_BUILD: "cp38-manylinux_x86_64 cp39-manylinux_x86_64 cp310-manylinux_x86_64 cp311-manylinux_x86_64 cp312-manylinux_x86_64"
|
||||
CIBW_BEFORE_ALL: |
|
||||
yum update && yum install -y openssl-devel && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
yum update && yum install -y openssl-devel protobuf-compiler && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
CIBW_ENVIRONMENT: "PATH=$HOME/.cargo/bin:$PATH"
|
||||
|
||||
- name: List built packages
|
||||
|
||||
@@ -39,13 +39,13 @@ ENV PATH="/root/.cargo/bin:${PATH}"
|
||||
|
||||
# install dependencies
|
||||
RUN apt update -y \
|
||||
&& apt install -y git build-essential libssl-dev pkg-config \
|
||||
&& apt install -y git build-essential libssl-dev pkg-config protobuf-compiler \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& apt clean
|
||||
|
||||
# install rustup from rustup.rs
|
||||
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
|
||||
&& rustc --version && cargo --version
|
||||
&& rustc --version && cargo --version && protoc --version
|
||||
|
||||
# pull the github repository
|
||||
RUN cd /opt \
|
||||
|
||||
@@ -4,10 +4,10 @@ set -euxo pipefail
|
||||
# Check if sudo is available
|
||||
if command -v sudo >/dev/null 2>&1; then
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y libssl-dev pkg-config
|
||||
sudo apt-get install -y libssl-dev pkg-config protobuf-compiler
|
||||
else
|
||||
apt-get update
|
||||
apt-get install -y libssl-dev pkg-config
|
||||
apt-get install -y libssl-dev pkg-config protobuf-compiler
|
||||
fi
|
||||
|
||||
# Install rustup (Rust installer and version manager)
|
||||
@@ -21,3 +21,4 @@ source $HOME/.cargo/env
|
||||
# Verify installation
|
||||
rustc --version
|
||||
cargo --version
|
||||
protoc --version
|
||||
|
||||
@@ -4,9 +4,11 @@ version = "0.0.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["huggingface"]
|
||||
default = ["huggingface", "grpc-client"]
|
||||
huggingface = ["tokenizers"]
|
||||
tiktoken = ["tiktoken-rs"]
|
||||
grpc-client = []
|
||||
grpc-server = []
|
||||
|
||||
[lib]
|
||||
name = "sglang_router_rs"
|
||||
@@ -52,6 +54,18 @@ anyhow = "1.0"
|
||||
tokenizers = { version = "0.21.4", optional = true }
|
||||
tiktoken-rs = { version = "0.5", optional = true }
|
||||
|
||||
# gRPC and Protobuf dependencies
|
||||
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
||||
prost = "0.13"
|
||||
prost-types = "0.13"
|
||||
deadpool = { version = "0.12", features = ["managed", "rt_tokio_1"] }
|
||||
backoff = { version = "0.4", features = ["tokio"] }
|
||||
strum = { version = "0.26", features = ["derive"] }
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.12"
|
||||
prost-build = "0.13"
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
tower = { version = "0.5", features = ["util"] }
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# Must include:
|
||||
include Cargo.toml # Rust project configuration
|
||||
include build.rs # Build script for protobuf generation
|
||||
recursive-include src *.rs # Rust source files
|
||||
recursive-include src/proto *.proto # Protobuf definitions
|
||||
|
||||
35
sgl-router/build.rs
Normal file
35
sgl-router/build.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Only regenerate if the proto file changes
|
||||
println!("cargo:rerun-if-changed=src/proto/sglang_scheduler.proto");
|
||||
|
||||
// Configure protobuf compilation with custom settings
|
||||
let config = prost_build::Config::new();
|
||||
|
||||
// Skip serde for types that use prost_types::Struct
|
||||
// These cause conflicts and we don't need serde for all generated types
|
||||
|
||||
// Configure tonic-build for gRPC code generation
|
||||
tonic_build::configure()
|
||||
// Generate both client and server code
|
||||
.build_server(true)
|
||||
.build_client(true)
|
||||
// Add a module-level attribute for documentation and clippy warnings
|
||||
.server_mod_attribute(
|
||||
"sglang.grpc.scheduler",
|
||||
"#[allow(unused, clippy::mixed_attributes_style)]",
|
||||
)
|
||||
.client_mod_attribute(
|
||||
"sglang.grpc.scheduler",
|
||||
"#[allow(unused, clippy::mixed_attributes_style)]",
|
||||
)
|
||||
// Compile the proto file with the custom config
|
||||
.compile_protos_with_config(
|
||||
config,
|
||||
&["src/proto/sglang_scheduler.proto"],
|
||||
&["src/proto"],
|
||||
)?;
|
||||
|
||||
println!("cargo:warning=Protobuf compilation completed successfully");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
318
sgl-router/src/grpc/client.rs
Normal file
318
sgl-router/src/grpc/client.rs
Normal file
@@ -0,0 +1,318 @@
|
||||
use std::time::Duration;
|
||||
use tonic::{transport::Channel, Request};
|
||||
use tracing::debug;
|
||||
|
||||
// Include the generated protobuf code
|
||||
pub mod proto {
|
||||
tonic::include_proto!("sglang.grpc.scheduler");
|
||||
}
|
||||
|
||||
// The generated module structure depends on the package name in the .proto file
|
||||
// package sglang.grpc.scheduler; generates a nested module structure
|
||||
|
||||
/// gRPC client for SGLang scheduler
|
||||
pub struct SglangSchedulerClient {
|
||||
client: proto::sglang_scheduler_client::SglangSchedulerClient<Channel>,
|
||||
}
|
||||
|
||||
impl SglangSchedulerClient {
|
||||
/// Create a new client and connect to the scheduler
|
||||
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
debug!("Connecting to SGLang scheduler at {}", endpoint);
|
||||
|
||||
let channel = Channel::from_shared(endpoint.to_string())?
|
||||
.timeout(Duration::from_secs(30))
|
||||
.connect()
|
||||
.await?;
|
||||
|
||||
let client = proto::sglang_scheduler_client::SglangSchedulerClient::new(channel);
|
||||
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
/// Initialize the connection
|
||||
pub async fn initialize(
|
||||
&mut self,
|
||||
client_id: String,
|
||||
) -> Result<proto::InitializeResponse, Box<dyn std::error::Error>> {
|
||||
let request = Request::new(proto::InitializeRequest {
|
||||
client_id,
|
||||
client_version: "0.1.0".to_string(),
|
||||
mode: proto::initialize_request::Mode::Regular as i32,
|
||||
});
|
||||
|
||||
let response = self.client.initialize(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
/// Submit a generation request (returns streaming response)
|
||||
pub async fn generate_stream(
|
||||
&mut self,
|
||||
req: proto::GenerateRequest,
|
||||
) -> Result<tonic::Streaming<proto::GenerateResponse>, Box<dyn std::error::Error>> {
|
||||
let request = Request::new(req);
|
||||
let response = self.client.generate(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
/// Perform health check
|
||||
pub async fn health_check(
|
||||
&mut self,
|
||||
) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
|
||||
let request = Request::new(proto::HealthCheckRequest {
|
||||
include_detailed_metrics: false,
|
||||
});
|
||||
|
||||
let response = self.client.health_check(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
|
||||
/// Abort a request
|
||||
pub async fn abort_request(
|
||||
&mut self,
|
||||
request_id: String,
|
||||
reason: String,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let request = Request::new(proto::AbortRequest { request_id, reason });
|
||||
|
||||
self.client.abort(request).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Flush cache
|
||||
pub async fn flush_cache(
|
||||
&mut self,
|
||||
flush_all: bool,
|
||||
session_ids: &[String],
|
||||
) -> Result<proto::FlushCacheResponse, Box<dyn std::error::Error>> {
|
||||
let request = Request::new(proto::FlushCacheRequest {
|
||||
flush_all,
|
||||
session_ids: session_ids.to_vec(),
|
||||
});
|
||||
|
||||
let response = self.client.flush_cache(request).await?;
|
||||
Ok(response.into_inner())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_proto_types_compilation() {
|
||||
// Test that protobuf types can be constructed
|
||||
let init_req = proto::InitializeRequest {
|
||||
client_id: "test-client".to_string(),
|
||||
client_version: "0.1.0".to_string(),
|
||||
mode: 0,
|
||||
};
|
||||
assert_eq!(init_req.client_id, "test-client");
|
||||
assert_eq!(init_req.client_version, "0.1.0");
|
||||
assert_eq!(init_req.mode, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_request_construction() {
|
||||
let sampling_params = proto::SamplingParams {
|
||||
temperature: 0.7,
|
||||
max_new_tokens: 128,
|
||||
top_p: 0.9,
|
||||
top_k: 50,
|
||||
stop: vec!["</s>".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let gen_req = proto::GenerateRequest {
|
||||
request_id: "test-req-123".to_string(),
|
||||
input: Some(proto::generate_request::Input::Text(
|
||||
"Hello world".to_string(),
|
||||
)),
|
||||
sampling_params: Some(sampling_params),
|
||||
return_logprob: true,
|
||||
logprob_start_len: 0,
|
||||
top_logprobs_num: 5,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(gen_req.request_id, "test-req-123");
|
||||
if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input {
|
||||
assert_eq!(text, "Hello world");
|
||||
}
|
||||
assert!(gen_req.return_logprob);
|
||||
assert_eq!(gen_req.top_logprobs_num, 5);
|
||||
|
||||
let params = gen_req.sampling_params.unwrap();
|
||||
assert_eq!(params.temperature, 0.7);
|
||||
assert_eq!(params.max_new_tokens, 128);
|
||||
assert_eq!(params.stop, vec!["</s>"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_health_check_request() {
|
||||
let health_req = proto::HealthCheckRequest {
|
||||
include_detailed_metrics: true,
|
||||
};
|
||||
assert!(health_req.include_detailed_metrics);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_abort_request_construction() {
|
||||
let abort_req = proto::AbortRequest {
|
||||
request_id: "req-456".to_string(),
|
||||
reason: "User canceled".to_string(),
|
||||
};
|
||||
assert_eq!(abort_req.request_id, "req-456");
|
||||
assert_eq!(abort_req.reason, "User canceled");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flush_cache_request() {
|
||||
let flush_req = proto::FlushCacheRequest {
|
||||
flush_all: true,
|
||||
session_ids: vec!["session1".to_string(), "session2".to_string()],
|
||||
};
|
||||
assert!(flush_req.flush_all);
|
||||
assert_eq!(flush_req.session_ids.len(), 2);
|
||||
assert_eq!(flush_req.session_ids[0], "session1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sampling_params_defaults() {
|
||||
let params = proto::SamplingParams::default();
|
||||
assert_eq!(params.temperature, 0.0);
|
||||
assert_eq!(params.max_new_tokens, 0);
|
||||
assert_eq!(params.top_p, 0.0);
|
||||
assert_eq!(params.top_k, 0);
|
||||
assert!(params.stop.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multimodal_inputs() {
|
||||
let mm_inputs = proto::MultimodalInputs {
|
||||
image_urls: vec!["http://example.com/image.jpg".to_string()],
|
||||
video_urls: vec![],
|
||||
audio_urls: vec![],
|
||||
image_data: vec![],
|
||||
video_data: vec![],
|
||||
audio_data: vec![],
|
||||
modalities: vec!["image".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(mm_inputs.image_urls.len(), 1);
|
||||
assert_eq!(mm_inputs.image_urls[0], "http://example.com/image.jpg");
|
||||
assert_eq!(mm_inputs.modalities[0], "image");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_params() {
|
||||
let session_params = proto::SessionParams {
|
||||
session_id: "sess-789".to_string(),
|
||||
request_id: "req-101".to_string(),
|
||||
offset: 100,
|
||||
replace: true,
|
||||
drop_previous_output: false,
|
||||
};
|
||||
|
||||
assert_eq!(session_params.session_id, "sess-789");
|
||||
assert_eq!(session_params.request_id, "req-101");
|
||||
assert_eq!(session_params.offset, 100);
|
||||
assert!(session_params.replace);
|
||||
assert!(!session_params.drop_previous_output);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_embed_request() {
|
||||
let embed_req = proto::EmbedRequest {
|
||||
request_id: "embed-req-202".to_string(),
|
||||
input: Some(proto::embed_request::Input::Text(
|
||||
"This is a test sentence for embedding".to_string(),
|
||||
)),
|
||||
log_metrics: true,
|
||||
data_parallel_rank: 0,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(embed_req.request_id, "embed-req-202");
|
||||
if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input {
|
||||
assert_eq!(text, "This is a test sentence for embedding");
|
||||
}
|
||||
assert!(embed_req.log_metrics);
|
||||
assert_eq!(embed_req.data_parallel_rank, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_connect_invalid_endpoint() {
|
||||
// Test connecting to an invalid endpoint should return error
|
||||
let result = SglangSchedulerClient::connect("invalid://endpoint").await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokenized_input() {
|
||||
let tokenized = proto::TokenizedInput {
|
||||
original_text: "Hello world".to_string(),
|
||||
input_ids: vec![1, 15043, 1917, 2],
|
||||
};
|
||||
|
||||
assert_eq!(tokenized.original_text, "Hello world");
|
||||
assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]);
|
||||
}
|
||||
|
||||
// Test response type construction
|
||||
#[test]
|
||||
fn test_generate_stream_chunk() {
|
||||
let chunk = proto::GenerateStreamChunk {
|
||||
token_id: 1234,
|
||||
text: " world".to_string(),
|
||||
prompt_tokens: 5,
|
||||
completion_tokens: 2,
|
||||
cached_tokens: 3,
|
||||
generation_time: 0.025,
|
||||
queue_time: 10,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
assert_eq!(chunk.token_id, 1234);
|
||||
assert_eq!(chunk.text, " world");
|
||||
assert_eq!(chunk.prompt_tokens, 5);
|
||||
assert_eq!(chunk.completion_tokens, 2);
|
||||
assert_eq!(chunk.cached_tokens, 3);
|
||||
assert_eq!(chunk.generation_time, 0.025);
|
||||
assert_eq!(chunk.queue_time, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_info() {
|
||||
let model_info = proto::ModelInfo {
|
||||
model_name: "Meta-Llama-3-8B-Instruct".to_string(),
|
||||
max_context_length: 8192,
|
||||
vocab_size: 128256,
|
||||
supports_tool_calling: true,
|
||||
supports_vision: false,
|
||||
special_tokens: vec![
|
||||
"<|begin_of_text|>".to_string(),
|
||||
"<|end_of_text|>".to_string(),
|
||||
],
|
||||
model_type: "llama".to_string(),
|
||||
num_layers: 32,
|
||||
hidden_size: 4096,
|
||||
num_attention_heads: 32,
|
||||
num_key_value_heads: 8,
|
||||
tokenizer_type: "llama".to_string(),
|
||||
eos_token_ids: vec![128001, 128009],
|
||||
pad_token_id: 128001,
|
||||
bos_token_id: 128000,
|
||||
};
|
||||
|
||||
assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct");
|
||||
assert_eq!(model_info.max_context_length, 8192);
|
||||
assert_eq!(model_info.vocab_size, 128256);
|
||||
assert!(model_info.supports_tool_calling);
|
||||
assert!(!model_info.supports_vision);
|
||||
assert_eq!(model_info.special_tokens.len(), 2);
|
||||
assert_eq!(model_info.num_layers, 32);
|
||||
assert_eq!(model_info.eos_token_ids, vec![128001, 128009]);
|
||||
}
|
||||
}
|
||||
8
sgl-router/src/grpc/mod.rs
Normal file
8
sgl-router/src/grpc/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
//! gRPC client module for communicating with SGLang scheduler
|
||||
//!
|
||||
//! This module provides a gRPC client implementation for the SGLang router.
|
||||
|
||||
pub mod client;
|
||||
|
||||
// Re-export the client
|
||||
pub use client::{proto, SglangSchedulerClient};
|
||||
@@ -3,6 +3,8 @@ pub mod config;
|
||||
pub mod logging;
|
||||
use std::collections::HashMap;
|
||||
pub mod core;
|
||||
#[cfg(feature = "grpc-client")]
|
||||
pub mod grpc;
|
||||
pub mod metrics;
|
||||
pub mod middleware;
|
||||
pub mod policies;
|
||||
|
||||
@@ -7,7 +7,7 @@ import "google/protobuf/struct.proto";
|
||||
|
||||
// Service definition for SGLang scheduler communication
|
||||
// This protocol bridges the Rust router and Python scheduler
|
||||
service SGLangScheduler {
|
||||
service SglangScheduler {
|
||||
// Initialize connection and get model info
|
||||
rpc Initialize(InitializeRequest) returns (InitializeResponse);
|
||||
|
||||
@@ -21,7 +21,7 @@ service SGLangScheduler {
|
||||
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
|
||||
|
||||
// Abort a running request
|
||||
rpc AbortRequest(AbortRequest) returns (AbortResponse);
|
||||
rpc Abort(AbortRequest) returns (AbortResponse);
|
||||
|
||||
// Flush KV cache
|
||||
rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);
|
||||
|
||||
Reference in New Issue
Block a user