diff --git a/sgl-router/build.rs b/sgl-router/build.rs
index 90b3c6101..48e7e324c 100644
--- a/sgl-router/build.rs
+++ b/sgl-router/build.rs
@@ -13,6 +13,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         // Generate both client and server code
         .build_server(true)
         .build_client(true)
+        // Add protoc arguments for proto3 optional support
+        .protoc_arg("--experimental_allow_proto3_optional")
         // Add a module-level attribute for documentation and clippy warnings
         .server_mod_attribute(
             "sglang.grpc.scheduler",
diff --git a/sgl-router/src/grpc/client.rs b/sgl-router/src/grpc/client.rs
index efd141a0b..b68224b3c 100644
--- a/sgl-router/src/grpc/client.rs
+++ b/sgl-router/src/grpc/client.rs
@@ -97,7 +97,7 @@ mod tests {
     fn test_generate_request_construction() {
         let sampling_params = proto::SamplingParams {
             temperature: 0.7,
-            max_new_tokens: 128,
+            max_new_tokens: Some(128),
            top_p: 0.9,
            top_k: 50,
            stop: vec!["</s>".to_string()],
@@ -126,7 +126,7 @@ mod tests {
 
        let params = gen_req.sampling_params.unwrap();
        assert_eq!(params.temperature, 0.7);
-        assert_eq!(params.max_new_tokens, 128);
+        assert_eq!(params.max_new_tokens, Some(128));
        assert_eq!(params.stop, vec!["</s>"]);
    }
 
@@ -155,7 +155,7 @@ mod tests {
    fn test_sampling_params_defaults() {
        let params = proto::SamplingParams::default();
        assert_eq!(params.temperature, 0.0);
-        assert_eq!(params.max_new_tokens, 0);
+        assert_eq!(params.max_new_tokens, None);
        assert_eq!(params.top_p, 0.0);
        assert_eq!(params.top_k, 0);
        assert!(params.stop.is_empty());
diff --git a/sgl-router/src/proto/sglang_scheduler.proto b/sgl-router/src/proto/sglang_scheduler.proto
index e4c87925e..e2adc7863 100644
--- a/sgl-router/src/proto/sglang_scheduler.proto
+++ b/sgl-router/src/proto/sglang_scheduler.proto
@@ -36,7 +36,7 @@ message SamplingParams {
   float presence_penalty = 6;
   float repetition_penalty = 7;
 
-  int32 max_new_tokens = 8;
+  optional int32 max_new_tokens = 8;
   repeated string stop = 9;
   repeated int32 stop_token_ids = 10;
   bool skip_special_tokens = 11;
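Context for the three files above: with `--experimental_allow_proto3_optional`, prost generates `Option<i32>` for the `optional int32` field, so an unset `max_new_tokens` becomes distinguishable from an explicit `0`. A minimal sketch of the resulting API (field names from this diff; `Default` comes from prost codegen):

```rust
// Sketch: prost maps `optional int32 max_new_tokens = 8;` to `Option<i32>`.
let capped = proto::SamplingParams {
    max_new_tokens: Some(128), // caller-imposed cap
    ..Default::default()
};
let unset = proto::SamplingParams::default();
assert_eq!(capped.max_new_tokens, Some(128));
assert_eq!(unset.max_new_tokens, None); // "not set", not 0 -- the backend picks the default
```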
diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs
index 4898fb451..e0efd3e8c 100644
--- a/sgl-router/src/routers/grpc/router.rs
+++ b/sgl-router/src/routers/grpc/router.rs
@@ -4,12 +4,16 @@ use crate::config::types::RetryConfig;
 use crate::core::{
     BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
 };
-use crate::grpc::SglangSchedulerClient;
+use crate::grpc::{proto, SglangSchedulerClient};
 use crate::metrics::RouterMetrics;
 use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
+use crate::protocols::spec::{
+    ChatCompletionRequest, ChatMessage, ContentPart, ResponseFormat, StringOrArray,
+    UserMessageContent,
+};
 use crate::reasoning_parser::ParserFactory;
 use crate::routers::RouterTrait;
-use crate::tokenizer::traits::Tokenizer;
+use crate::tokenizer::{chat_template::ChatMessage as TokenizerChatMessage, traits::Tokenizer};
 use crate::tool_parser::ParserRegistry;
 use async_trait::async_trait;
 use axum::{
@@ -21,7 +25,16 @@ use axum::{
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
-use tracing::{info, warn};
+use tracing::{debug, error, info, warn};
+use uuid::Uuid;
+
+// Data structures for processing
+#[derive(Debug)]
+pub struct ProcessedMessages {
+    pub text: String,
+    pub multimodal_inputs: Option<proto::MultimodalInputs>,
+    pub stop_sequences: Option<StringOrArray>,
+}
 
 /// gRPC router implementation for SGLang
 #[allow(dead_code)] // Fields will be used once implementation is complete
@@ -161,6 +174,345 @@ impl GrpcRouter {
             circuit_breaker_config: core_cb_config,
         })
     }
+
+    // ============ Chat Implementation ============
+
+    /// Main route_chat implementation
+    async fn route_chat_impl(
+        &self,
+        _headers: Option<&HeaderMap>,
+        body: &ChatCompletionRequest,
+        model_id: Option<&str>,
+    ) -> Response {
+        debug!(
+            "Processing chat completion request for model: {:?}",
+            model_id
+        );
+
+        // Step 1: Select worker (fail fast if no workers available)
+        let worker = match self.select_worker_for_request(model_id, None) {
+            Some(w) => w,
+            None => {
+                warn!("No available workers for model: {:?}", model_id);
+                return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
+            }
+        };
+
+        debug!("Selected worker: {}", worker.url());
+
+        // Step 2: Get gRPC client for worker (fail fast if can't connect)
+        let client = match self.get_or_create_grpc_client(worker.url()).await {
+            Ok(c) => c,
+            Err(e) => {
+                error!("Failed to get gRPC client: {}", e);
+                return (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    format!("Failed to get gRPC client: {}", e),
+                )
+                    .into_response();
+            }
+        };
+
+        // Step 3: Process messages and apply chat template
+        let processed_messages = match self.process_chat_messages(body) {
+            Ok(msgs) => msgs,
+            Err(e) => {
+                error!("Failed to process chat messages: {}", e);
+                return (StatusCode::BAD_REQUEST, e.to_string()).into_response();
+            }
+        };
+
+        // Step 4: Tokenize the processed text
+        let encoding = match self.tokenizer.encode(&processed_messages.text) {
+            Ok(encoding) => encoding,
+            Err(e) => {
+                error!("Tokenization failed: {}", e);
+                return (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    format!("Tokenization failed: {}", e),
+                )
+                    .into_response();
+            }
+        };
+
+        let token_ids = encoding.token_ids().to_vec();
+        debug!("Tokenized {} tokens from input", token_ids.len());
+
+        // Step 5: Build tool constraints if needed
+        let structural_tag = if let Some(tools) = &body.tools {
+            self.generate_tool_constraints(tools, &body.tool_choice, &body.model)
+        } else {
+            None
+        };
+
+        // Step 6: Build SamplingParams for gRPC
+        let sampling_params = match self.build_grpc_sampling_params(body, structural_tag) {
+            Ok(params) => params,
+            Err(e) => {
+                error!("Failed to build sampling parameters: {}", e);
+                return (
+                    StatusCode::BAD_REQUEST,
+                    format!("Invalid sampling parameters: {}", e),
+                )
+                    .into_response();
+            }
+        };
+
+        // Step 7: Create GenerateRequest
+        let grpc_request = proto::GenerateRequest {
+            request_id: format!("chatcmpl-{}", Uuid::new_v4()),
+            tokenized: Some(proto::TokenizedInput {
+                original_text: processed_messages.text.clone(),
+                input_ids: token_ids.into_iter().map(|id| id as i32).collect(),
+            }),
+            mm_inputs: processed_messages.multimodal_inputs,
+            sampling_params: Some(sampling_params),
+            return_logprob: body.logprobs,
+            logprob_start_len: -1,
+            top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
+            return_hidden_states: body.return_hidden_states,
+            ..Default::default()
+        };
+
+        // Step 8: Handle streaming vs non-streaming
+        if body.stream {
+            self.handle_streaming_chat(client, grpc_request, body).await
+        } else {
+            self.handle_non_streaming_chat(client, grpc_request, body)
+                .await
+        }
+    }
+
+    // ============ Helper Methods ============
+
+    /// Process chat messages and apply template
+    fn process_chat_messages(
+        &self,
+        request: &ChatCompletionRequest,
+    ) -> Result<ProcessedMessages, String> {
+        let tokenizer_messages = self.convert_messages_for_tokenizer(&request.messages)?;
+
+        // Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
+        let formatted_text = if let Some(hf_tokenizer) = self
+            .tokenizer
+            .as_any()
+            .downcast_ref::<crate::tokenizer::huggingface::HuggingFaceTokenizer>()
+        {
+            hf_tokenizer
+                .apply_chat_template(&tokenizer_messages, true)
+                .map_err(|e| format!("Failed to apply chat template: {}", e))?
+        } else {
+            return Err(
+                "gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
+            );
+        };
+
+        // Placeholder for multimodal inputs
+        let multimodal_inputs = None;
+
+        Ok(ProcessedMessages {
+            text: formatted_text,
+            multimodal_inputs,
+            stop_sequences: request.stop.clone(),
+        })
+    }
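+
+    // Illustrative note, not part of this PR's code: the formatted text depends
+    // entirely on the model's chat template. For a ChatML-style template,
+    // [system: "You are helpful.", user: "Hi!"] would come out roughly as:
+    //
+    //   <|im_start|>system\nYou are helpful.<|im_end|>\n
+    //   <|im_start|>user\nHi!<|im_end|>\n
+    //   <|im_start|>assistant\n
+    //
+    // with the trailing assistant header appended because apply_chat_template
+    // is called with `true` (presumably add_generation_prompt).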
+
+    /// Convert spec ChatMessage enum to tokenizer ChatMessage struct
+    fn convert_messages_for_tokenizer(
+        &self,
+        messages: &[ChatMessage],
+    ) -> Result<Vec<TokenizerChatMessage>, String> {
+        let mut converted = Vec::new();
+
+        for message in messages {
+            let tokenizer_msg = match message {
+                ChatMessage::System { content, .. } => TokenizerChatMessage::new("system", content),
+                ChatMessage::User { content, .. } => {
+                    let text_content = match content {
+                        UserMessageContent::Text(text) => text.clone(),
+                        UserMessageContent::Parts(parts) => {
+                            // Simple text extraction for now - multimodal is placeholder
+                            parts
+                                .iter()
+                                .filter_map(|part| match part {
+                                    ContentPart::Text { text } => Some(text.as_str()),
+                                    ContentPart::ImageUrl { .. } => None, // Skip images for now
+                                })
+                                .collect::<Vec<_>>()
+                                .join(" ")
+                        }
+                    };
+                    TokenizerChatMessage::new("user", text_content)
+                }
+                ChatMessage::Assistant { content, .. } => {
+                    // Simple content extraction - no special tool/reasoning formatting
+                    TokenizerChatMessage::new("assistant", content.as_deref().unwrap_or(""))
+                }
+                ChatMessage::Tool { content, .. } => TokenizerChatMessage::new("tool", content),
+                ChatMessage::Function { content, .. } => {
+                    TokenizerChatMessage::new("function", content)
+                }
+            };
+            converted.push(tokenizer_msg);
+        }
+
+        Ok(converted)
+    }
+
+    /// Build gRPC SamplingParams from OpenAI request
+    fn build_grpc_sampling_params(
+        &self,
+        request: &ChatCompletionRequest,
+        structural_tag: Option<String>,
+    ) -> Result<proto::SamplingParams, String> {
+        let stop_sequences = self.extract_stop_strings(request);
+
+        // Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated)
+        // If neither is specified, use None to let the backend decide the default
+        #[allow(deprecated)]
+        let max_new_tokens = request
+            .max_completion_tokens
+            .or(request.max_tokens)
+            .map(|v| v as i32);
+
+        #[allow(deprecated)]
+        Ok(proto::SamplingParams {
+            temperature: request.temperature.unwrap_or(1.0),
+            top_p: request.top_p.unwrap_or(1.0),
+            top_k: request.top_k.unwrap_or(-1),
+            min_p: request.min_p.unwrap_or(0.0),
+            frequency_penalty: request.frequency_penalty.unwrap_or(0.0),
+            presence_penalty: request.presence_penalty.unwrap_or(0.0),
+            repetition_penalty: request.repetition_penalty.unwrap_or(1.0),
+            max_new_tokens,
+            stop: stop_sequences,
+            stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
+            skip_special_tokens: request.skip_special_tokens,
+            n: request.n.unwrap_or(1) as i32,
+            structural_tag: structural_tag.unwrap_or_default(),
+            constraint: self.build_constraint(request)?,
+            ..Default::default()
+        })
+    }
+
+    /// Extract stop strings from request
+    fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> {
+        match &request.stop {
+            Some(StringOrArray::String(s)) => vec![s.clone()],
+            Some(StringOrArray::Array(arr)) => arr.clone(),
+            None => vec![],
+        }
+    }
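+
+    // Illustrative sketch of the max-token precedence implemented above
+    // (not actual code in this PR): writing the logic as
+    // resolve(a, b) = a.or(b).map(|v| v as i32), we get
+    //
+    //   resolve(Some(100), Some(200)) == Some(100)  // max_completion_tokens wins
+    //   resolve(None,      Some(200)) == Some(200)  // falls back to deprecated max_tokens
+    //   resolve(None,      None)      == None       // backend chooses its own default
+    //
+    // which is exactly why max_new_tokens had to become optional in the proto.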
+
+    /// Build constraint for structured generation
+    fn build_constraint(
+        &self,
+        request: &ChatCompletionRequest,
+    ) -> Result<Option<proto::sampling_params::Constraint>, String> {
+        if let Some(ResponseFormat::JsonSchema { json_schema }) = &request.response_format {
+            let schema_str = serde_json::to_string(&json_schema.schema)
+                .map_err(|e| format!("Failed to serialize JSON schema: {}", e))?;
+            return Ok(Some(proto::sampling_params::Constraint::JsonSchema(
+                schema_str,
+            )));
+        }
+
+        if let Some(ebnf) = &request.ebnf {
+            return Ok(Some(proto::sampling_params::Constraint::EbnfGrammar(
+                ebnf.clone(),
+            )));
+        }
+
+        if let Some(regex) = &request.regex {
+            return Ok(Some(proto::sampling_params::Constraint::Regex(
+                regex.clone(),
+            )));
+        }
+
+        Ok(None)
+    }
+
+    /// Generate tool constraints for structured generation
+    fn generate_tool_constraints(
+        &self,
+        _tools: &[crate::protocols::spec::Tool],
+        _tool_choice: &Option<crate::protocols::spec::ToolChoice>,
+        model: &str,
+    ) -> Option<String> {
+        let _parser = self.tool_parser_registry.get_parser(model)?;
+        None
+    }
+
+    /// Select a worker for the request
+    fn select_worker_for_request(
+        &self,
+        model_id: Option<&str>,
+        text: Option<&str>,
+    ) -> Option<Arc<dyn crate::core::Worker>> {
+        // Get workers for the specified model, filtered by connection mode
+        let workers = self.worker_registry.get_workers_filtered(
+            model_id,
+            Some(WorkerType::Regular),
+            Some(crate::core::ConnectionMode::Grpc { port: None }),
+            false, // get all workers, we'll filter by is_available() next
+        );
+
+        // Filter by availability (health + circuit breaker)
+        let available: Vec<Arc<dyn crate::core::Worker>> = workers
+            .iter()
+            .filter(|w| w.is_available())
+            .cloned()
+            .collect();
+
+        if available.is_empty() {
+            return None;
+        }
+
+        // Get the appropriate policy for this model
+        let policy = match model_id {
+            Some(model) => self.policy_registry.get_policy_or_default(model),
+            None => self.policy_registry.get_default_policy(),
+        };
+
+        // Select worker using the policy
+        let idx = policy.select_worker(&available, text)?;
+        Some(available[idx].clone())
+    }
+
+    /// Get or create a gRPC client for the worker
+    async fn get_or_create_grpc_client(
+        &self,
+        worker_url: &str,
+    ) -> Result<SglangSchedulerClient, String> {
+        debug!("Creating new gRPC client for worker: {}", worker_url);
+        SglangSchedulerClient::connect(worker_url)
+            .await
+            .map_err(|e| format!("Failed to connect to gRPC server: {}", e))
+    }
+
+    /// Placeholder for streaming handler (to be implemented in Phase 2)
+    async fn handle_streaming_chat(
+        &self,
+        _client: SglangSchedulerClient,
+        _request: proto::GenerateRequest,
+        _original_request: &ChatCompletionRequest,
+    ) -> Response {
+        (StatusCode::NOT_IMPLEMENTED, "Streaming not yet implemented").into_response()
+    }
+
+    /// Placeholder for non-streaming handler (to be implemented in Phase 3)
+    async fn handle_non_streaming_chat(
+        &self,
+        _client: SglangSchedulerClient,
+        _request: proto::GenerateRequest,
+        _original_request: &ChatCompletionRequest,
+    ) -> Response {
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "Non-streaming not yet implemented",
+        )
+            .into_response()
+    }
 }
 
 impl std::fmt::Debug for GrpcRouter {
@@ -212,11 +564,11 @@ impl RouterTrait for GrpcRouter {
 
     async fn route_chat(
         &self,
-        _headers: Option<&HeaderMap>,
-        _body: &crate::protocols::spec::ChatCompletionRequest,
-        _model_id: Option<&str>,
+        headers: Option<&HeaderMap>,
+        body: &crate::protocols::spec::ChatCompletionRequest,
+        model_id: Option<&str>,
     ) -> Response {
-        (StatusCode::NOT_IMPLEMENTED).into_response()
+        self.route_chat_impl(headers, body, model_id).await
     }
 
     async fn route_completion(
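Worth noting for a follow-up: despite its name, `get_or_create_grpc_client` dials a fresh connection on every request. If client reuse is planned for a later phase, a per-URL cache could look roughly like the sketch below. The `ClientCache` type, its `clients` field, and the assumption that `SglangSchedulerClient` is cheaply cloneable are all hypothetical, not part of this PR:

```rust
use std::collections::HashMap;
use tokio::sync::RwLock;

use crate::grpc::SglangSchedulerClient;

// Hypothetical per-URL client cache; sketch only.
struct ClientCache {
    clients: RwLock<HashMap<String, SglangSchedulerClient>>,
}

impl ClientCache {
    async fn get_or_create(&self, url: &str) -> Result<SglangSchedulerClient, String> {
        // Fast path: reuse an existing client (assumes Clone is cheap, as with tonic channels).
        if let Some(client) = self.clients.read().await.get(url) {
            return Ok(client.clone());
        }
        // Slow path: dial once and cache the result.
        let client = SglangSchedulerClient::connect(url)
            .await
            .map_err(|e| format!("Failed to connect to gRPC server: {}", e))?;
        self.clients
            .write()
            .await
            .insert(url.to_string(), client.clone());
        Ok(client)
    }
}
```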
diff --git a/sgl-router/src/tokenizer/huggingface.rs b/sgl-router/src/tokenizer/huggingface.rs
index 02dce5a0a..7cb930d18 100644
--- a/sgl-router/src/tokenizer/huggingface.rs
+++ b/sgl-router/src/tokenizer/huggingface.rs
@@ -210,6 +210,10 @@ impl TokenizerTrait for HuggingFaceTokenizer {
     fn id_to_token(&self, id: TokenIdType) -> Option<String> {
         self.reverse_vocab.get(&id).cloned()
     }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
 }
 
 #[cfg(test)]
diff --git a/sgl-router/src/tokenizer/mock.rs b/sgl-router/src/tokenizer/mock.rs
index afb91543c..9b0cd5cdf 100644
--- a/sgl-router/src/tokenizer/mock.rs
+++ b/sgl-router/src/tokenizer/mock.rs
@@ -109,4 +109,8 @@ impl TokenizerTrait for MockTokenizer {
     fn id_to_token(&self, id: u32) -> Option<String> {
         self.reverse_vocab.get(&id).cloned()
     }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
 }
diff --git a/sgl-router/src/tokenizer/tiktoken.rs b/sgl-router/src/tokenizer/tiktoken.rs
index 0af5a9791..74607b419 100644
--- a/sgl-router/src/tokenizer/tiktoken.rs
+++ b/sgl-router/src/tokenizer/tiktoken.rs
@@ -170,6 +170,10 @@ impl TokenizerTrait for TiktokenTokenizer {
         // We can only decode IDs to text
         None
     }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
 }
 
 #[cfg(test)]
diff --git a/sgl-router/src/tokenizer/traits.rs b/sgl-router/src/tokenizer/traits.rs
index 275dd822f..3ef2c4fe0 100644
--- a/sgl-router/src/tokenizer/traits.rs
+++ b/sgl-router/src/tokenizer/traits.rs
@@ -22,6 +22,9 @@ pub trait Tokenizer: Encoder + Decoder {
     fn get_special_tokens(&self) -> &SpecialTokens;
     fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
     fn id_to_token(&self, id: TokenIdType) -> Option<String>;
+
+    /// Enable downcasting to concrete types
+    fn as_any(&self) -> &dyn std::any::Any;
 }
 
 /// Contains the results of tokenizing text: token IDs, string tokens, and their spans
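The `as_any` method added to the `Tokenizer` trait and to each implementation is what lets `process_chat_messages` recover the concrete tokenizer behind a trait object. A minimal usage sketch, assuming the module paths from this diff; the free function `as_huggingface` is illustrative, not part of the PR:

```rust
use crate::tokenizer::{huggingface::HuggingFaceTokenizer, traits::Tokenizer};

// Sketch: downcasting a trait object to its concrete type via as_any.
fn as_huggingface(tok: &dyn Tokenizer) -> Option<&HuggingFaceTokenizer> {
    tok.as_any().downcast_ref::<HuggingFaceTokenizer>()
}

// At a call site holding an Arc<dyn Tokenizer>:
// match as_huggingface(tokenizer.as_ref()) {
//     Some(hf) => { /* apply the chat template */ }
//     None => { /* reject: the gRPC router requires a HuggingFace tokenizer */ }
// }
```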