From 5e21d6aec052e371ab9763f71ab911a72bfca22e Mon Sep 17 00:00:00 2001 From: Chang Su Date: Thu, 25 Sep 2025 14:21:22 -0700 Subject: [PATCH] refactor: Move `grpc/client.rs` to `grpc_client/sglang_scheduler.rs` (#10924) --- sgl-router/src/core/worker.rs | 2 +- sgl-router/src/core/worker_builder.rs | 2 +- sgl-router/src/grpc/mod.rs | 8 - sgl-router/src/grpc_client/mod.rs | 3 + .../sglang_scheduler.rs} | 144 ++++++++++++++++- sgl-router/src/lib.rs | 2 +- sgl-router/src/routers/grpc/pd_router.rs | 2 +- sgl-router/src/routers/grpc/router.rs | 148 ++---------------- 8 files changed, 162 insertions(+), 149 deletions(-) delete mode 100644 sgl-router/src/grpc/mod.rs create mode 100644 sgl-router/src/grpc_client/mod.rs rename sgl-router/src/{grpc/client.rs => grpc_client/sglang_scheduler.rs} (57%) diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index 08903ba72..722510fc2 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -1,7 +1,7 @@ use super::{CircuitBreaker, WorkerError, WorkerResult}; use crate::core::CircuitState; use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder}; -use crate::grpc::SglangSchedulerClient; +use crate::grpc_client::SglangSchedulerClient; use crate::metrics::RouterMetrics; use async_trait::async_trait; use futures; diff --git a/sgl-router/src/core/worker_builder.rs b/sgl-router/src/core/worker_builder.rs index 0011fda3a..9dd03b30a 100644 --- a/sgl-router/src/core/worker_builder.rs +++ b/sgl-router/src/core/worker_builder.rs @@ -2,7 +2,7 @@ use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; use super::worker::{ BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType, }; -use crate::grpc::client::SglangSchedulerClient; +use crate::grpc_client::SglangSchedulerClient; use std::collections::HashMap; /// Builder for creating BasicWorker instances with fluent API diff --git a/sgl-router/src/grpc/mod.rs b/sgl-router/src/grpc/mod.rs deleted file mode 100644 index 331a6a538..000000000 --- a/sgl-router/src/grpc/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! gRPC client module for communicating with SGLang scheduler -//! -//! This module provides a gRPC client implementation for the SGLang router. - -pub mod client; - -// Re-export the client -pub use client::{proto, SglangSchedulerClient}; diff --git a/sgl-router/src/grpc_client/mod.rs b/sgl-router/src/grpc_client/mod.rs new file mode 100644 index 000000000..cb5bca140 --- /dev/null +++ b/sgl-router/src/grpc_client/mod.rs @@ -0,0 +1,3 @@ +pub mod sglang_scheduler; + +pub use sglang_scheduler::{proto, SglangSchedulerClient}; diff --git a/sgl-router/src/grpc/client.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs similarity index 57% rename from sgl-router/src/grpc/client.rs rename to sgl-router/src/grpc_client/sglang_scheduler.rs index b68224b3c..0b87f85b3 100644 --- a/sgl-router/src/grpc/client.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -2,6 +2,8 @@ use std::time::Duration; use tonic::{transport::Channel, Request}; use tracing::debug; +use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat}; + // Include the generated protobuf code pub mod proto { tonic::include_proto!("sglang.grpc.scheduler"); @@ -75,6 +77,142 @@ impl SglangSchedulerClient { self.client.abort(request).await?; Ok(()) } + + /// Build a single SGLang GenerateRequest from OpenAI ChatCompletionRequest + pub fn build_generate_request( + &self, + request_id: String, + body: &ChatCompletionRequest, + processed_text: String, + token_ids: Vec, + multimodal_inputs: Option, + tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value) + ) -> Result { + // Build sampling params + let sampling_params = self.build_grpc_sampling_params(body, tool_call_constraint)?; + + let grpc_request = proto::GenerateRequest { + request_id, + tokenized: Some(proto::TokenizedInput { + original_text: processed_text, + input_ids: token_ids, + }), + mm_inputs: multimodal_inputs, + sampling_params: Some(sampling_params), + return_logprob: body.logprobs, + logprob_start_len: -1, + top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32, + return_hidden_states: body.return_hidden_states, + ..Default::default() + }; + + Ok(grpc_request) + } + + /// Build gRPC SamplingParams from OpenAI request + fn build_grpc_sampling_params( + &self, + request: &ChatCompletionRequest, + tool_call_constraint: Option<(String, String)>, + ) -> Result { + let stop_sequences = self.extract_stop_strings(request); + + // Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated) + // If neither is specified, use None to let the backend decide the default + #[allow(deprecated)] + let max_new_tokens = request + .max_completion_tokens + .or(request.max_tokens) + .map(|v| v as i32); + + // Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none" + let skip_special_tokens = if request.tools.is_some() { + match &request.tool_choice { + Some(crate::protocols::spec::ToolChoice::Value( + crate::protocols::spec::ToolChoiceValue::None, + )) => request.skip_special_tokens, + Some(_) => false, // tool_choice is not "none" + None => false, // TODO: this assumes tool_choice defaults to "auto" when tools present + } + } else { + request.skip_special_tokens + }; + + #[allow(deprecated)] + Ok(proto::SamplingParams { + temperature: request.temperature.unwrap_or(1.0), + top_p: request.top_p.unwrap_or(1.0), + top_k: request.top_k.unwrap_or(-1), + min_p: request.min_p.unwrap_or(0.0), + frequency_penalty: request.frequency_penalty.unwrap_or(0.0), + presence_penalty: request.presence_penalty.unwrap_or(0.0), + repetition_penalty: request.repetition_penalty.unwrap_or(1.0), + max_new_tokens, + stop: stop_sequences, + stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), + skip_special_tokens, + n: request.n.unwrap_or(1) as i32, + constraint: self.build_constraint(request, tool_call_constraint)?, + ..Default::default() + }) + } + + /// Extract stop strings from request + fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec { + match &request.stop { + Some(crate::protocols::spec::StringOrArray::String(s)) => vec![s.clone()], + Some(crate::protocols::spec::StringOrArray::Array(arr)) => arr.clone(), + None => vec![], + } + } + + /// Build constraint for structured generation + fn build_constraint( + &self, + request: &ChatCompletionRequest, + tool_call_constraint: Option<(String, String)>, + ) -> Result, String> { + let mut constraints = Vec::new(); + + if let Some(ResponseFormat::JsonSchema { json_schema }) = &request.response_format { + let schema_str = serde_json::to_string(&json_schema.schema) + .map_err(|e| format!("Failed to serialize JSON schema: {}", e))?; + constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); + } + + if let Some(ebnf) = &request.ebnf { + constraints.push(proto::sampling_params::Constraint::EbnfGrammar( + ebnf.clone(), + )); + } + + if let Some(regex) = &request.regex { + constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); + } + + // Handle tool call constraint + if let Some((constraint_type, constraint_value)) = tool_call_constraint { + if !constraints.is_empty() { + return Err("Constrained decoding is not compatible with tool calls.".to_string()); + } + let tool_constraint = match constraint_type.as_str() { + "structural_tag" => { + proto::sampling_params::Constraint::StructuralTag(constraint_value) + } + "json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value), + "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), + "regex" => proto::sampling_params::Constraint::Regex(constraint_value), + _ => return Err(format!("Unknown constraint type: {}", constraint_type)), + }; + constraints.push(tool_constraint); + } + + match constraints.len() { + 0 => Ok(None), + 1 => Ok(constraints.pop()), + _ => Err("Multiple constraints are not allowed.".to_string()), + } + } } #[cfg(test)] @@ -230,22 +368,16 @@ mod tests { fn test_generate_stream_chunk() { let chunk = proto::GenerateStreamChunk { token_id: 1234, - text: " world".to_string(), prompt_tokens: 5, completion_tokens: 2, cached_tokens: 3, - generation_time: 0.025, - queue_time: 10, ..Default::default() }; assert_eq!(chunk.token_id, 1234); - assert_eq!(chunk.text, " world"); assert_eq!(chunk.prompt_tokens, 5); assert_eq!(chunk.completion_tokens, 2); assert_eq!(chunk.cached_tokens, 3); - assert_eq!(chunk.generation_time, 0.025); - assert_eq!(chunk.queue_time, 10); } // TODO: ModelInfo not in current proto - skip test diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 36c6a02d7..a1d2cabc4 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; pub mod core; pub mod data_connector; #[cfg(feature = "grpc-client")] -pub mod grpc; +pub mod grpc_client; pub mod mcp; pub mod metrics; pub mod middleware; diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index a60744518..d28560dbf 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -4,7 +4,7 @@ use crate::config::types::RetryConfig; use crate::core::{ BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType, }; -use crate::grpc::SglangSchedulerClient; +use crate::grpc_client::SglangSchedulerClient; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; use crate::reasoning_parser::ParserFactory; diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index c3a83389e..f4f4337b3 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -17,10 +17,10 @@ use crate::config::types::RetryConfig; use crate::core::{ BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType, }; -use crate::grpc::{proto, SglangSchedulerClient}; +use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::metrics::RouterMetrics; use crate::policies::{LoadBalancingPolicy, PolicyRegistry}; -use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat, StringOrArray}; +use crate::protocols::spec::{ChatCompletionRequest, StringOrArray}; use crate::reasoning_parser::ParserFactory; use crate::routers::RouterTrait; use crate::tokenizer::traits::Tokenizer; @@ -247,40 +247,31 @@ impl GrpcRouter { None }; - // Step 6: Build SamplingParams for gRPC - let sampling_params = match self.build_grpc_sampling_params(body, tool_call_constraint) { - Ok(params) => params, + // Step 6: Build the base gRPC request + let request_id = format!("chatcmpl-{}", Uuid::new_v4()); + let base_request = match client.build_generate_request( + request_id, + body, + processed_messages.text.clone(), + token_ids.into_iter().map(|id| id as i32).collect(), + processed_messages.multimodal_inputs, + tool_call_constraint, // Pass the full tuple (type, value) + ) { + Ok(request) => request, Err(e) => { - error!("Failed to build sampling parameters: {}", e); + error!("Failed to build gRPC request: {}", e); return ( StatusCode::BAD_REQUEST, - format!("Invalid sampling parameters: {}", e), + format!("Invalid request parameters: {}", e), ) .into_response(); } }; - // Step 7: Create GenerateRequest - let grpc_request = proto::GenerateRequest { - request_id: format!("chatcmpl-{}", Uuid::new_v4()), - tokenized: Some(proto::TokenizedInput { - original_text: processed_messages.text.clone(), - input_ids: token_ids.into_iter().map(|id| id as i32).collect(), - }), - mm_inputs: processed_messages.multimodal_inputs, - sampling_params: Some(sampling_params), - return_logprob: body.logprobs, - logprob_start_len: -1, - top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32, - return_hidden_states: body.return_hidden_states, - ..Default::default() - }; - - // Step 8: Handle streaming vs non-streaming if body.stream { - self.handle_streaming_chat(client, grpc_request, body).await + self.handle_streaming_chat(client, base_request, body).await } else { - self.handle_non_streaming_chat(client, grpc_request, body) + self.handle_non_streaming_chat(client, base_request, body) .await } } @@ -547,111 +538,6 @@ impl GrpcRouter { Ok(()) } - /// Build gRPC SamplingParams from OpenAI request - fn build_grpc_sampling_params( - &self, - request: &ChatCompletionRequest, - tool_call_constraint: Option<(String, String)>, - ) -> Result { - let stop_sequences = self.extract_stop_strings(request); - - // Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated) - // If neither is specified, use None to let the backend decide the default - #[allow(deprecated)] - let max_new_tokens = request - .max_completion_tokens - .or(request.max_tokens) - .map(|v| v as i32); - - // Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none" - let skip_special_tokens = if request.tools.is_some() { - match &request.tool_choice { - Some(crate::protocols::spec::ToolChoice::Value( - crate::protocols::spec::ToolChoiceValue::None, - )) => request.skip_special_tokens, - Some(_) => false, // tool_choice is not "none" - None => false, // TODO: this assumes tool_choice defaults to "auto" when tools present - } - } else { - request.skip_special_tokens - }; - - #[allow(deprecated)] - Ok(proto::SamplingParams { - temperature: request.temperature.unwrap_or(1.0), - top_p: request.top_p.unwrap_or(1.0), - top_k: request.top_k.unwrap_or(-1), - min_p: request.min_p.unwrap_or(0.0), - frequency_penalty: request.frequency_penalty.unwrap_or(0.0), - presence_penalty: request.presence_penalty.unwrap_or(0.0), - repetition_penalty: request.repetition_penalty.unwrap_or(1.0), - max_new_tokens, - stop: stop_sequences, - stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), - skip_special_tokens, - n: request.n.unwrap_or(1) as i32, - constraint: self.build_constraint(request, tool_call_constraint)?, - ..Default::default() - }) - } - - /// Extract stop strings from request - fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec { - match &request.stop { - Some(StringOrArray::String(s)) => vec![s.clone()], - Some(StringOrArray::Array(arr)) => arr.clone(), - None => vec![], - } - } - - /// Build constraint for structured generation - fn build_constraint( - &self, - request: &ChatCompletionRequest, - tool_call_constraint: Option<(String, String)>, - ) -> Result, String> { - let mut constraints = Vec::new(); - - if let Some(ResponseFormat::JsonSchema { json_schema }) = &request.response_format { - let schema_str = serde_json::to_string(&json_schema.schema) - .map_err(|e| format!("Failed to serialize JSON schema: {}", e))?; - constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); - } - - if let Some(ebnf) = &request.ebnf { - constraints.push(proto::sampling_params::Constraint::EbnfGrammar( - ebnf.clone(), - )); - } - - if let Some(regex) = &request.regex { - constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); - } - - // Handle tool call constraint - if let Some((constraint_type, constraint_value)) = tool_call_constraint { - if !constraints.is_empty() { - return Err("Constrained decoding is not compatible with tool calls.".to_string()); - } - let tool_constraint = match constraint_type.as_str() { - "structural_tag" => { - proto::sampling_params::Constraint::StructuralTag(constraint_value) - } - "json_schema" => proto::sampling_params::Constraint::JsonSchema(constraint_value), - "ebnf" => proto::sampling_params::Constraint::EbnfGrammar(constraint_value), - "regex" => proto::sampling_params::Constraint::Regex(constraint_value), - _ => return Err(format!("Unknown constraint type: {}", constraint_type)), - }; - constraints.push(tool_constraint); - } - - match constraints.len() { - 0 => Ok(None), - 1 => Ok(constraints.pop()), - _ => Err("Multiple constraints are not allowed.".to_string()), - } - } - /// Generate tool constraints for structured generation fn generate_tool_constraints( &self,