diff --git a/.github/workflows/release-pypi-router.yml b/.github/workflows/release-pypi-router.yml index 948b3f584..5653cd1d1 100644 --- a/.github/workflows/release-pypi-router.yml +++ b/.github/workflows/release-pypi-router.yml @@ -47,7 +47,7 @@ jobs: env: CIBW_BUILD: "cp38-manylinux_x86_64 cp39-manylinux_x86_64 cp310-manylinux_x86_64 cp311-manylinux_x86_64 cp312-manylinux_x86_64" CIBW_BEFORE_ALL: | - yum update && yum install -y openssl-devel && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + yum update && yum install -y openssl-devel protobuf-compiler && curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y CIBW_ENVIRONMENT: "PATH=$HOME/.cargo/bin:$PATH" - name: List built packages diff --git a/docker/Dockerfile.router b/docker/Dockerfile.router index 07633e502..ded98bb8a 100644 --- a/docker/Dockerfile.router +++ b/docker/Dockerfile.router @@ -39,13 +39,13 @@ ENV PATH="/root/.cargo/bin:${PATH}" # install dependencies RUN apt update -y \ - && apt install -y git build-essential libssl-dev pkg-config \ + && apt install -y git build-essential libssl-dev pkg-config protobuf-compiler \ && rm -rf /var/lib/apt/lists/* \ && apt clean # install rustup from rustup.rs RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \ - && rustc --version && cargo --version + && rustc --version && cargo --version && protoc --version # pull the github repository RUN cd /opt \ diff --git a/scripts/ci/ci_install_rust.sh b/scripts/ci/ci_install_rust.sh index 519155dfb..ac042fc9a 100755 --- a/scripts/ci/ci_install_rust.sh +++ b/scripts/ci/ci_install_rust.sh @@ -4,10 +4,10 @@ set -euxo pipefail # Check if sudo is available if command -v sudo >/dev/null 2>&1; then sudo apt-get update - sudo apt-get install -y libssl-dev pkg-config + sudo apt-get install -y libssl-dev pkg-config protobuf-compiler else apt-get update - apt-get install -y libssl-dev pkg-config + apt-get install -y libssl-dev pkg-config protobuf-compiler fi # Install rustup (Rust installer and version manager) @@ -21,3 +21,4 @@ source $HOME/.cargo/env # Verify installation rustc --version cargo --version +protoc --version diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index e0defacdf..1b20f3cba 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -4,9 +4,11 @@ version = "0.0.0" edition = "2021" [features] -default = ["huggingface"] +default = ["huggingface", "grpc-client"] huggingface = ["tokenizers"] tiktoken = ["tiktoken-rs"] +grpc-client = [] +grpc-server = [] [lib] name = "sglang_router_rs" @@ -52,6 +54,18 @@ anyhow = "1.0" tokenizers = { version = "0.21.4", optional = true } tiktoken-rs = { version = "0.5", optional = true } +# gRPC and Protobuf dependencies +tonic = { version = "0.12", features = ["tls", "gzip", "transport"] } +prost = "0.13" +prost-types = "0.13" +deadpool = { version = "0.12", features = ["managed", "rt_tokio_1"] } +backoff = { version = "0.4", features = ["tokio"] } +strum = { version = "0.26", features = ["derive"] } + +[build-dependencies] +tonic-build = "0.12" +prost-build = "0.13" + [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } tower = { version = "0.5", features = ["util"] } diff --git a/sgl-router/MANIFEST.in b/sgl-router/MANIFEST.in index e1d6e7a90..4baa6c84f 100644 --- a/sgl-router/MANIFEST.in +++ b/sgl-router/MANIFEST.in @@ -1,3 +1,5 @@ # Must include: include Cargo.toml # Rust project configuration +include build.rs # Build script for protobuf generation recursive-include src *.rs # Rust source files +recursive-include src/proto *.proto # Protobuf definitions diff --git a/sgl-router/build.rs b/sgl-router/build.rs new file mode 100644 index 000000000..90b3c6101 --- /dev/null +++ b/sgl-router/build.rs @@ -0,0 +1,35 @@ +fn main() -> Result<(), Box> { + // Only regenerate if the proto file changes + println!("cargo:rerun-if-changed=src/proto/sglang_scheduler.proto"); + + // Configure protobuf compilation with custom settings + let config = prost_build::Config::new(); + + // Skip serde for types that use prost_types::Struct + // These cause conflicts and we don't need serde for all generated types + + // Configure tonic-build for gRPC code generation + tonic_build::configure() + // Generate both client and server code + .build_server(true) + .build_client(true) + // Add a module-level attribute for documentation and clippy warnings + .server_mod_attribute( + "sglang.grpc.scheduler", + "#[allow(unused, clippy::mixed_attributes_style)]", + ) + .client_mod_attribute( + "sglang.grpc.scheduler", + "#[allow(unused, clippy::mixed_attributes_style)]", + ) + // Compile the proto file with the custom config + .compile_protos_with_config( + config, + &["src/proto/sglang_scheduler.proto"], + &["src/proto"], + )?; + + println!("cargo:warning=Protobuf compilation completed successfully"); + + Ok(()) +} diff --git a/sgl-router/src/grpc/client.rs b/sgl-router/src/grpc/client.rs new file mode 100644 index 000000000..f31227bb1 --- /dev/null +++ b/sgl-router/src/grpc/client.rs @@ -0,0 +1,318 @@ +use std::time::Duration; +use tonic::{transport::Channel, Request}; +use tracing::debug; + +// Include the generated protobuf code +pub mod proto { + tonic::include_proto!("sglang.grpc.scheduler"); +} + +// The generated module structure depends on the package name in the .proto file +// package sglang.grpc.scheduler; generates a nested module structure + +/// gRPC client for SGLang scheduler +pub struct SglangSchedulerClient { + client: proto::sglang_scheduler_client::SglangSchedulerClient, +} + +impl SglangSchedulerClient { + /// Create a new client and connect to the scheduler + pub async fn connect(endpoint: &str) -> Result> { + debug!("Connecting to SGLang scheduler at {}", endpoint); + + let channel = Channel::from_shared(endpoint.to_string())? + .timeout(Duration::from_secs(30)) + .connect() + .await?; + + let client = proto::sglang_scheduler_client::SglangSchedulerClient::new(channel); + + Ok(Self { client }) + } + + /// Initialize the connection + pub async fn initialize( + &mut self, + client_id: String, + ) -> Result> { + let request = Request::new(proto::InitializeRequest { + client_id, + client_version: "0.1.0".to_string(), + mode: proto::initialize_request::Mode::Regular as i32, + }); + + let response = self.client.initialize(request).await?; + Ok(response.into_inner()) + } + + /// Submit a generation request (returns streaming response) + pub async fn generate_stream( + &mut self, + req: proto::GenerateRequest, + ) -> Result, Box> { + let request = Request::new(req); + let response = self.client.generate(request).await?; + Ok(response.into_inner()) + } + + /// Perform health check + pub async fn health_check( + &mut self, + ) -> Result> { + let request = Request::new(proto::HealthCheckRequest { + include_detailed_metrics: false, + }); + + let response = self.client.health_check(request).await?; + Ok(response.into_inner()) + } + + /// Abort a request + pub async fn abort_request( + &mut self, + request_id: String, + reason: String, + ) -> Result<(), Box> { + let request = Request::new(proto::AbortRequest { request_id, reason }); + + self.client.abort(request).await?; + Ok(()) + } + + /// Flush cache + pub async fn flush_cache( + &mut self, + flush_all: bool, + session_ids: &[String], + ) -> Result> { + let request = Request::new(proto::FlushCacheRequest { + flush_all, + session_ids: session_ids.to_vec(), + }); + + let response = self.client.flush_cache(request).await?; + Ok(response.into_inner()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_proto_types_compilation() { + // Test that protobuf types can be constructed + let init_req = proto::InitializeRequest { + client_id: "test-client".to_string(), + client_version: "0.1.0".to_string(), + mode: 0, + }; + assert_eq!(init_req.client_id, "test-client"); + assert_eq!(init_req.client_version, "0.1.0"); + assert_eq!(init_req.mode, 0); + } + + #[test] + fn test_generate_request_construction() { + let sampling_params = proto::SamplingParams { + temperature: 0.7, + max_new_tokens: 128, + top_p: 0.9, + top_k: 50, + stop: vec!["".to_string()], + ..Default::default() + }; + + let gen_req = proto::GenerateRequest { + request_id: "test-req-123".to_string(), + input: Some(proto::generate_request::Input::Text( + "Hello world".to_string(), + )), + sampling_params: Some(sampling_params), + return_logprob: true, + logprob_start_len: 0, + top_logprobs_num: 5, + ..Default::default() + }; + + assert_eq!(gen_req.request_id, "test-req-123"); + if let Some(proto::generate_request::Input::Text(text)) = &gen_req.input { + assert_eq!(text, "Hello world"); + } + assert!(gen_req.return_logprob); + assert_eq!(gen_req.top_logprobs_num, 5); + + let params = gen_req.sampling_params.unwrap(); + assert_eq!(params.temperature, 0.7); + assert_eq!(params.max_new_tokens, 128); + assert_eq!(params.stop, vec![""]); + } + + #[test] + fn test_health_check_request() { + let health_req = proto::HealthCheckRequest { + include_detailed_metrics: true, + }; + assert!(health_req.include_detailed_metrics); + } + + #[test] + fn test_abort_request_construction() { + let abort_req = proto::AbortRequest { + request_id: "req-456".to_string(), + reason: "User canceled".to_string(), + }; + assert_eq!(abort_req.request_id, "req-456"); + assert_eq!(abort_req.reason, "User canceled"); + } + + #[test] + fn test_flush_cache_request() { + let flush_req = proto::FlushCacheRequest { + flush_all: true, + session_ids: vec!["session1".to_string(), "session2".to_string()], + }; + assert!(flush_req.flush_all); + assert_eq!(flush_req.session_ids.len(), 2); + assert_eq!(flush_req.session_ids[0], "session1"); + } + + #[test] + fn test_sampling_params_defaults() { + let params = proto::SamplingParams::default(); + assert_eq!(params.temperature, 0.0); + assert_eq!(params.max_new_tokens, 0); + assert_eq!(params.top_p, 0.0); + assert_eq!(params.top_k, 0); + assert!(params.stop.is_empty()); + } + + #[test] + fn test_multimodal_inputs() { + let mm_inputs = proto::MultimodalInputs { + image_urls: vec!["http://example.com/image.jpg".to_string()], + video_urls: vec![], + audio_urls: vec![], + image_data: vec![], + video_data: vec![], + audio_data: vec![], + modalities: vec!["image".to_string()], + ..Default::default() + }; + + assert_eq!(mm_inputs.image_urls.len(), 1); + assert_eq!(mm_inputs.image_urls[0], "http://example.com/image.jpg"); + assert_eq!(mm_inputs.modalities[0], "image"); + } + + #[test] + fn test_session_params() { + let session_params = proto::SessionParams { + session_id: "sess-789".to_string(), + request_id: "req-101".to_string(), + offset: 100, + replace: true, + drop_previous_output: false, + }; + + assert_eq!(session_params.session_id, "sess-789"); + assert_eq!(session_params.request_id, "req-101"); + assert_eq!(session_params.offset, 100); + assert!(session_params.replace); + assert!(!session_params.drop_previous_output); + } + + #[test] + fn test_embed_request() { + let embed_req = proto::EmbedRequest { + request_id: "embed-req-202".to_string(), + input: Some(proto::embed_request::Input::Text( + "This is a test sentence for embedding".to_string(), + )), + log_metrics: true, + data_parallel_rank: 0, + ..Default::default() + }; + + assert_eq!(embed_req.request_id, "embed-req-202"); + if let Some(proto::embed_request::Input::Text(text)) = &embed_req.input { + assert_eq!(text, "This is a test sentence for embedding"); + } + assert!(embed_req.log_metrics); + assert_eq!(embed_req.data_parallel_rank, 0); + } + + #[tokio::test] + async fn test_client_connect_invalid_endpoint() { + // Test connecting to an invalid endpoint should return error + let result = SglangSchedulerClient::connect("invalid://endpoint").await; + assert!(result.is_err()); + } + + #[test] + fn test_tokenized_input() { + let tokenized = proto::TokenizedInput { + original_text: "Hello world".to_string(), + input_ids: vec![1, 15043, 1917, 2], + }; + + assert_eq!(tokenized.original_text, "Hello world"); + assert_eq!(tokenized.input_ids, vec![1, 15043, 1917, 2]); + } + + // Test response type construction + #[test] + fn test_generate_stream_chunk() { + let chunk = proto::GenerateStreamChunk { + token_id: 1234, + text: " world".to_string(), + prompt_tokens: 5, + completion_tokens: 2, + cached_tokens: 3, + generation_time: 0.025, + queue_time: 10, + ..Default::default() + }; + + assert_eq!(chunk.token_id, 1234); + assert_eq!(chunk.text, " world"); + assert_eq!(chunk.prompt_tokens, 5); + assert_eq!(chunk.completion_tokens, 2); + assert_eq!(chunk.cached_tokens, 3); + assert_eq!(chunk.generation_time, 0.025); + assert_eq!(chunk.queue_time, 10); + } + + #[test] + fn test_model_info() { + let model_info = proto::ModelInfo { + model_name: "Meta-Llama-3-8B-Instruct".to_string(), + max_context_length: 8192, + vocab_size: 128256, + supports_tool_calling: true, + supports_vision: false, + special_tokens: vec![ + "<|begin_of_text|>".to_string(), + "<|end_of_text|>".to_string(), + ], + model_type: "llama".to_string(), + num_layers: 32, + hidden_size: 4096, + num_attention_heads: 32, + num_key_value_heads: 8, + tokenizer_type: "llama".to_string(), + eos_token_ids: vec![128001, 128009], + pad_token_id: 128001, + bos_token_id: 128000, + }; + + assert_eq!(model_info.model_name, "Meta-Llama-3-8B-Instruct"); + assert_eq!(model_info.max_context_length, 8192); + assert_eq!(model_info.vocab_size, 128256); + assert!(model_info.supports_tool_calling); + assert!(!model_info.supports_vision); + assert_eq!(model_info.special_tokens.len(), 2); + assert_eq!(model_info.num_layers, 32); + assert_eq!(model_info.eos_token_ids, vec![128001, 128009]); + } +} diff --git a/sgl-router/src/grpc/mod.rs b/sgl-router/src/grpc/mod.rs new file mode 100644 index 000000000..331a6a538 --- /dev/null +++ b/sgl-router/src/grpc/mod.rs @@ -0,0 +1,8 @@ +//! gRPC client module for communicating with SGLang scheduler +//! +//! This module provides a gRPC client implementation for the SGLang router. + +pub mod client; + +// Re-export the client +pub use client::{proto, SglangSchedulerClient}; diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index ec29a1740..e41942c14 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -3,6 +3,8 @@ pub mod config; pub mod logging; use std::collections::HashMap; pub mod core; +#[cfg(feature = "grpc-client")] +pub mod grpc; pub mod metrics; pub mod middleware; pub mod policies; diff --git a/sgl-router/src/proto/sglang_scheduler.proto b/sgl-router/src/proto/sglang_scheduler.proto index be8bb09eb..1ea2855a4 100644 --- a/sgl-router/src/proto/sglang_scheduler.proto +++ b/sgl-router/src/proto/sglang_scheduler.proto @@ -7,7 +7,7 @@ import "google/protobuf/struct.proto"; // Service definition for SGLang scheduler communication // This protocol bridges the Rust router and Python scheduler -service SGLangScheduler { +service SglangScheduler { // Initialize connection and get model info rpc Initialize(InitializeRequest) returns (InitializeResponse); @@ -21,7 +21,7 @@ service SGLangScheduler { rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); // Abort a running request - rpc AbortRequest(AbortRequest) returns (AbortResponse); + rpc Abort(AbortRequest) returns (AbortResponse); // Flush KV cache rpc FlushCache(FlushCacheRequest) returns (FlushCacheResponse);