From 33b3c0f85ffb647a1fc831c59c112bcfca5c06b8 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Tue, 30 Sep 2025 01:07:53 -0400 Subject: [PATCH] [router] grpc router generate endpoint support (#11070) Co-authored-by: Chang Su --- .../src/grpc_client/sglang_scheduler.rs | 135 +++++- sgl-router/src/routers/grpc/router.rs | 434 ++++++++++++++---- 2 files changed, 480 insertions(+), 89 deletions(-) diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 36a980235..4cb082b53 100644 --- a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -1,8 +1,12 @@ +use std::convert::TryFrom; use std::time::Duration; use tonic::{transport::Channel, Request}; use tracing::debug; -use crate::protocols::spec::{ChatCompletionRequest, ResponseFormat}; +use crate::protocols::spec::{ + ChatCompletionRequest, GenerateRequest, ResponseFormat, + SamplingParams as GenerateSamplingParams, StringOrArray, +}; // Include the generated protobuf code pub mod proto { @@ -112,6 +116,37 @@ impl SglangSchedulerClient { Ok(grpc_request) } + /// Build a basic GenerateRequest from the SGLang spec GenerateRequest + pub fn build_plain_generate_request( + &self, + request_id: String, + body: &GenerateRequest, + original_text: Option, + token_ids: Vec, + ) -> Result { + let sampling_params = + Self::build_sampling_params_from_plain(body.sampling_params.as_ref())?; + + let grpc_request = proto::GenerateRequest { + request_id, + tokenized: Some(proto::TokenizedInput { + original_text: original_text.unwrap_or_default(), + input_ids: token_ids, + }), + sampling_params: Some(sampling_params), + return_logprob: body.return_logprob, + logprob_start_len: -1, + top_logprobs_num: 0, + token_ids_logprob: vec![], + return_hidden_states: body.return_hidden_states, + stream: body.stream, + log_metrics: true, + ..Default::default() + }; + + Ok(grpc_request) + } + /// Build gRPC SamplingParams from OpenAI request fn build_grpc_sampling_params( &self, @@ -165,8 +200,8 @@ impl SglangSchedulerClient { /// Extract stop strings from request fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec { match &request.stop { - Some(crate::protocols::spec::StringOrArray::String(s)) => vec![s.clone()], - Some(crate::protocols::spec::StringOrArray::Array(arr)) => arr.clone(), + Some(StringOrArray::String(s)) => vec![s.clone()], + Some(StringOrArray::Array(arr)) => arr.clone(), None => vec![], } } @@ -218,6 +253,100 @@ impl SglangSchedulerClient { _ => Err("Multiple constraints are not allowed.".to_string()), } } + + fn build_single_constraint_from_plain( + params: &GenerateSamplingParams, + ) -> Result, String> { + let mut constraints = Vec::new(); + if let Some(json_schema) = ¶ms.json_schema { + constraints.push(proto::sampling_params::Constraint::JsonSchema( + json_schema.clone(), + )); + } + if let Some(regex) = ¶ms.regex { + constraints.push(proto::sampling_params::Constraint::Regex(regex.clone())); + } + if let Some(ebnf) = ¶ms.ebnf { + constraints.push(proto::sampling_params::Constraint::EbnfGrammar( + ebnf.clone(), + )); + } + + match constraints.len() { + 0 => Ok(None), + 1 => Ok(constraints.pop()), + _ => Err("Multiple structured constraints are not allowed".to_string()), + } + } + + fn build_sampling_params_from_plain( + params: Option<&GenerateSamplingParams>, + ) -> Result { + let mut sampling = proto::SamplingParams { + temperature: 1.0, + top_p: 1.0, + top_k: -1, + repetition_penalty: 1.0, + n: 1, + ..Default::default() + }; + + let Some(p) = params else { + return Ok(sampling); + }; + + // Simple field mappings using a macro + macro_rules! map_field { + ($field:ident) => { + if let Some(val) = p.$field { + sampling.$field = val; + } + }; + } + + map_field!(temperature); + map_field!(top_p); + map_field!(top_k); + map_field!(frequency_penalty); + map_field!(presence_penalty); + map_field!(repetition_penalty); + map_field!(min_p); + map_field!(ignore_eos); + map_field!(skip_special_tokens); + map_field!(no_stop_trim); + + // Handle stop sequences + if let Some(stop) = &p.stop { + match stop { + StringOrArray::String(s) => sampling.stop.push(s.clone()), + StringOrArray::Array(arr) => sampling.stop.extend(arr.clone()), + } + } + + // Handle stop token IDs + if let Some(stop_token_ids) = &p.stop_token_ids { + sampling.stop_token_ids = stop_token_ids.clone(); + } + + // Handle max_new_tokens with conversion + if let Some(max_new_tokens) = p.max_new_tokens { + sampling.max_new_tokens = + Some(i32::try_from(max_new_tokens).map_err(|_| { + "max_new_tokens must fit into a 32-bit signed integer".to_string() + })?); + } + + // Handle min_tokens with conversion + if let Some(min_tokens) = p.min_tokens { + sampling.min_new_tokens = i32::try_from(min_tokens) + .map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?; + } + + // Handle constraints (exactly one allowed) + sampling.constraint = Self::build_single_constraint_from_plain(p)?; + + Ok(sampling) + } } #[cfg(test)] diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index dce2ca6f7..483a8127f 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -27,12 +27,15 @@ use crate::reasoning_parser::ParserFactory; use crate::routers::RouterTrait; use crate::server::AppContext; use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; -use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoderBuilder}; +use crate::tokenizer::stop::{ + SequenceDecoderOutput, StopSequenceDecoder, StopSequenceDecoderBuilder, +}; use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::HuggingFaceTokenizer; use crate::tool_parser::ParserRegistry; -use serde_json::Value; -use std::time::{SystemTime, UNIX_EPOCH}; +use proto::generate_response::Response::{Chunk, Complete, Error}; +use serde_json::{json, Value}; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; use tokio_stream::StreamExt; use uuid::Uuid; @@ -124,28 +127,9 @@ impl GrpcRouter { debug!("Selected worker: {}", worker.url()); // Step 2: Get gRPC client from worker - let client = match worker.get_grpc_client().await { - Ok(Some(client_arc)) => { - // Clone the client from inside the Arc> - let client = client_arc.lock().await.clone(); - client - } - Ok(None) => { - error!("Selected worker is not a gRPC worker"); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Selected worker is not configured for gRPC", - ) - .into_response(); - } - Err(e) => { - error!("Failed to get gRPC client from worker: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to get gRPC client: {}", e), - ) - .into_response(); - } + let client = match Self::get_grpc_client_from_worker(&worker).await { + Ok(client) => client, + Err(response) => return response, }; // Step 3: Process messages and apply chat template @@ -209,6 +193,112 @@ impl GrpcRouter { } } + /// Main route_generate implementation + async fn route_generate_impl( + &self, + _headers: Option<&HeaderMap>, + body: &GenerateRequest, + model_id: Option<&str>, + ) -> Response { + debug!("Processing generate request for model: {:?}", model_id); + + // Step 1: Resolve input (text, prompt, or input_ids) + let (original_text, token_ids) = match self.resolve_generate_input(body) { + Ok(res) => res, + Err(msg) => { + error!("Invalid generate request: {}", msg); + return (StatusCode::BAD_REQUEST, msg).into_response(); + } + }; + + debug!("Resolved input with {} tokens", token_ids.len()); + + // Step 2: Select worker (fail fast if no workers available) + let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) { + Some(w) => w, + None => { + warn!("No available workers for model: {:?}", model_id); + return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); + } + }; + + debug!("Selected worker: {}", worker.url()); + + // Step 3: Get gRPC client from worker + let client = match Self::get_grpc_client_from_worker(&worker).await { + Ok(client) => client, + Err(response) => return response, + }; + + // Step 4: Build the gRPC request + let request_id = body + .rid + .clone() + .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4())); + + let request = match client.build_plain_generate_request( + request_id.clone(), + body, + original_text.clone(), + token_ids, + ) { + Ok(req) => req, + Err(e) => { + error!("Failed to build generate request: {}", e); + return (StatusCode::BAD_REQUEST, e).into_response(); + } + }; + + // Step 5: Get weight version for response metadata + let weight_version = worker + .metadata() + .labels + .get("weight_version") + .cloned() + .unwrap_or_else(|| "default".to_string()); + + // Step 6: Handle streaming vs non-streaming + if body.stream { + // TODO: Implement streaming support for generate endpoint + return ( + StatusCode::NOT_IMPLEMENTED, + "Streaming generate over gRPC is not supported yet", + ) + .into_response(); + } + + self.handle_non_streaming_generate(client, request, body, request_id, weight_version) + .await + } + + /// Get gRPC client from worker, returning appropriate error response on failure + async fn get_grpc_client_from_worker( + worker: &Arc, + ) -> Result { + let client_arc = worker + .get_grpc_client() + .await + .map_err(|e| { + error!("Failed to get gRPC client from worker: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to get gRPC client: {}", e), + ) + .into_response() + })? + .ok_or_else(|| { + error!("Selected worker is not a gRPC worker"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Selected worker is not configured for gRPC", + ) + .into_response() + })?; + + let client = client_arc.lock().await.clone(); + Ok(client) + } + /// Select a worker for the request fn select_worker_for_request( &self, @@ -265,7 +355,7 @@ impl GrpcRouter { Self::process_tool_call_arguments(&mut transformed_messages)?; // Convert tools to JSON values for template processing - let tools_json: Option> = request + let tools_json: Option> = request .tools .as_ref() .map(|tools| { @@ -284,7 +374,7 @@ impl GrpcRouter { if let Some(reasoning_effort) = &request.reasoning_effort { combined_template_kwargs.insert( "reasoning_effort".to_string(), - serde_json::Value::String(reasoning_effort.clone()), + Value::String(reasoning_effort.clone()), ); } @@ -413,9 +503,9 @@ impl GrpcRouter { part.as_object() .and_then(|obj| obj.get("type")?.as_str()) .and_then(|type_str| match type_str { - "image_url" => Some(serde_json::json!({"type": "image"})), - "video_url" => Some(serde_json::json!({"type": "video"})), - "audio_url" => Some(serde_json::json!({"type": "audio"})), + "image_url" => Some(json!({"type": "image"})), + "video_url" => Some(json!({"type": "video"})), + "audio_url" => Some(json!({"type": "audio"})), _ => None, }) .unwrap_or_else(|| part.clone()) @@ -456,7 +546,7 @@ impl GrpcRouter { }; // Parse JSON string to object (like Python json.loads) - match serde_json::from_str::(args_str) { + match serde_json::from_str::(args_str) { Ok(parsed) => *args = parsed, Err(e) => { return Err(format!( @@ -483,13 +573,63 @@ impl GrpcRouter { None } - /// Create a StopSequenceDecoder from the chat completion request + /// Resolve the generate input into optional original text and token IDs + fn resolve_generate_input( + &self, + request: &GenerateRequest, + ) -> Result<(Option, Vec), String> { + if let Some(text) = &request.text { + return self + .tokenize_single_text(text) + .map(|(original, ids)| (Some(original), ids)); + } + + // Handle input_ids - validate and convert + if let Some(input_ids) = &request.input_ids { + return match input_ids { + crate::protocols::spec::InputIds::Single(ids) => ids + .iter() + .map(|&id| u32::try_from(id)) + .collect::, _>>() + .map(|converted| (None, converted)) + .map_err(|_| "input_ids must be non-negative".to_string()), + crate::protocols::spec::InputIds::Batch(_) => { + Err("Batch input_ids are not supported over gRPC generate yet".to_string()) + } + }; + } + + Err("Either `text` or `input_ids` must be provided".to_string()) + } + + fn tokenize_single_text(&self, text: &str) -> Result<(String, Vec), String> { + let encoding = self + .tokenizer + .encode(text) + .map_err(|e| format!("Tokenization failed: {}", e))?; + Ok((text.to_string(), encoding.token_ids().to_vec())) + } + + fn internal_error_static(msg: &'static str) -> Response { + error!("{}", msg); + (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response() + } + + fn internal_error_message(message: String) -> Response { + error!("{}", message); + (StatusCode::INTERNAL_SERVER_ERROR, message).into_response() + } + + /// Create a StopSequenceDecoder from stop parameters fn create_stop_decoder( &self, - original_request: &ChatCompletionRequest, - ) -> crate::tokenizer::stop::StopSequenceDecoder { - // Extract stop sequences from request - let stop_sequences: Vec = match &original_request.stop { + stop: Option<&StringOrArray>, + stop_token_ids: Option<&Vec>, + skip_special_tokens: bool, + no_stop_trim: bool, + ) -> StopSequenceDecoder { + // Extract stop sequences + let stop_sequences: Vec = match stop { Some(StringOrArray::String(s)) => vec![s.clone()], Some(StringOrArray::Array(arr)) => arr.clone(), None => vec![], @@ -497,11 +637,11 @@ impl GrpcRouter { // Build stop sequence decoder let mut builder = StopSequenceDecoderBuilder::new(self.tokenizer.clone()) - .skip_special_tokens(original_request.skip_special_tokens); + .skip_special_tokens(skip_special_tokens); // Add stop sequences (visible if no_stop_trim is true, hidden otherwise) for seq in stop_sequences { - builder = if original_request.no_stop_trim { + builder = if no_stop_trim { builder.visible_stop_sequence(seq) } else { builder.stop_sequence(seq) @@ -509,9 +649,9 @@ impl GrpcRouter { } // Add stop token IDs (visible if no_stop_trim is true, hidden otherwise) - if let Some(stop_token_ids) = &original_request.stop_token_ids { - for &token_id in stop_token_ids { - builder = if original_request.no_stop_trim { + if let Some(token_ids) = stop_token_ids { + for &token_id in token_ids { + builder = if no_stop_trim { builder.visible_stop_token(token_id) } else { builder.stop_token(token_id) @@ -524,7 +664,7 @@ impl GrpcRouter { /// Process a chunk of tokens through the stop decoder fn process_chunk_tokens( - stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder, + stop_decoder: &mut StopSequenceDecoder, token_ids: &[u32], ) -> (String, bool) { let mut chunk_text = String::new(); @@ -562,7 +702,12 @@ impl GrpcRouter { request: proto::GenerateRequest, original_request: &ChatCompletionRequest, ) -> Response { - let mut stop_decoder = self.create_stop_decoder(original_request); + let mut stop_decoder = self.create_stop_decoder( + original_request.stop.as_ref(), + original_request.stop_token_ids.as_ref(), + original_request.skip_special_tokens, + original_request.no_stop_trim, + ); // Process streaming tokens let mut grpc_stream = match client.generate(request).await { @@ -589,7 +734,7 @@ impl GrpcRouter { }; match gen_response.response { - Some(proto::generate_response::Response::Chunk(chunk)) => { + Some(Chunk(chunk)) => { // Process tokens and check if we should stop let (chunk_text, should_stop) = Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids); @@ -599,7 +744,7 @@ impl GrpcRouter { } continue; } - Some(proto::generate_response::Response::Complete(_complete)) => { + Some(Complete(_complete)) => { // Flush any remaining text if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() { if !text.is_empty() { @@ -609,7 +754,7 @@ impl GrpcRouter { } break; } - Some(proto::generate_response::Response::Error(error)) => { + Some(Error(error)) => { error!("Generation error: {}", error.message); break; } @@ -629,26 +774,19 @@ impl GrpcRouter { request: proto::GenerateRequest, original_request: &ChatCompletionRequest, ) -> Response { - let mut stop_decoder = self.create_stop_decoder(original_request); - - // Small helpers to log + return a uniform 500 - let fail_str = |msg: &'static str| -> Response { - error!("{}", msg); - (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response() - }; - let fail_fmt = |prefix: &str, e: &dyn std::fmt::Display| -> Response { - error!("{}{}", prefix, e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("{}{}", prefix, e), - ) - .into_response() - }; + let mut stop_decoder = self.create_stop_decoder( + original_request.stop.as_ref(), + original_request.stop_token_ids.as_ref(), + original_request.skip_special_tokens, + original_request.no_stop_trim, + ); // Start generation let mut stream = match client.generate(request).await { Ok(s) => s, - Err(e) => return fail_fmt("Failed to start generation: ", &e), + Err(e) => { + return Self::internal_error_message(format!("Failed to start generation: {}", e)) + } }; // Collect all responses (for n>1 support) @@ -656,28 +794,33 @@ impl GrpcRouter { while let Some(response) = stream.next().await { match response { Ok(gen_response) => match gen_response.response { - Some(proto::generate_response::Response::Complete(complete)) => { + Some(Complete(complete)) => { all_responses.push(complete); } - Some(proto::generate_response::Response::Error(err)) => { - error!("Generation failed for one choice: {}", err.message); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Generation failed: {}", err.message), + Some(Error(err)) => { + return Self::internal_error_message(format!( + "Generation failed: {}", + err.message + )); + } + Some(Chunk(_)) => { + return Self::internal_error_static( + "Unexpected chunk response for non-streaming request", ) - .into_response(); } - Some(proto::generate_response::Response::Chunk(_)) => { - return fail_str("Unexpected chunk response for non-streaming request") - } - None => return fail_str("Empty response from server"), + None => return Self::internal_error_static("Empty response from server"), }, - Err(e) => return fail_fmt("Failed to get GenerateResponse: ", &e), + Err(e) => { + return Self::internal_error_message(format!( + "Failed to get GenerateResponse: {}", + e + )) + } } } if all_responses.is_empty() { - return fail_str("No responses from server"); + return Self::internal_error_static("No responses from server"); } // Process each response into a ChatChoice @@ -689,12 +832,10 @@ impl GrpcRouter { { Ok(choice) => choices.push(choice), Err(e) => { - error!("Failed to process choice {}: {}", index, e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to process choice {}: {}", index, e), - ) - .into_response(); + return Self::internal_error_message(format!( + "Failed to process choice {}: {}", + index, e + )); } } } @@ -730,6 +871,127 @@ impl GrpcRouter { Json(response).into_response() } + /// Submit request and handle non-streaming response for the `/generate` endpoint + async fn handle_non_streaming_generate( + &self, + mut client: SglangSchedulerClient, + request: proto::GenerateRequest, + original_request: &GenerateRequest, + request_id: String, + weight_version: String, + ) -> Response { + let start_time = Instant::now(); + + let mut stream = match client.generate(request).await { + Ok(stream) => stream, + Err(e) => { + return Self::internal_error_message(format!("Failed to start generation: {}", e)) + } + }; + + let mut final_completion: Option = None; + + while let Some(result) = stream.next().await { + match result { + Ok(gen_response) => match gen_response.response { + Some(Complete(complete)) => { + final_completion = Some(complete); + break; + } + Some(Error(err)) => { + return Self::internal_error_message(format!( + "Generation failed: {}", + err.message + )); + } + Some(Chunk(_)) | None => continue, + }, + Err(e) => { + return Self::internal_error_message(format!( + "Failed to receive generate response: {}", + e + )) + } + } + } + + let mut complete = match final_completion { + Some(c) => c, + None => { + return Self::internal_error_static("No completion received from scheduler"); + } + }; + + // Create stop decoder from sampling params + let params = original_request.sampling_params.as_ref(); + let mut stop_decoder = self.create_stop_decoder( + params.and_then(|p| p.stop.as_ref()), + params.and_then(|p| p.stop_token_ids.as_ref()), + params.and_then(|p| p.skip_special_tokens).unwrap_or(true), + params.and_then(|p| p.no_stop_trim).unwrap_or(false), + ); + + // Process tokens through stop decoder + let outputs = match stop_decoder.process_tokens(&complete.output_ids) { + Ok(outputs) => outputs, + Err(e) => { + return Self::internal_error_message(format!("Failed to process tokens: {}", e)) + } + }; + + // Accumulate text with early breaks + let mut decoded_text = String::new(); + for output in outputs { + match output { + SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t), + SequenceDecoderOutput::StoppedWithText(t) => { + decoded_text.push_str(&t); + break; + } + SequenceDecoderOutput::Stopped => break, + SequenceDecoderOutput::Held => {} + } + } + + // Flush remaining text + if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { + decoded_text.push_str(&t); + } + + let output_ids = complete.output_ids.clone(); + + // Build base meta_info using json! macro + let mut meta_info = json!({ + "finish_reason": complete.finish_reason.clone(), + "prompt_tokens": complete.prompt_tokens, + "completion_tokens": complete.completion_tokens, + "cached_tokens": complete.cached_tokens, + "id": request_id, + "weight_version": weight_version, + "e2e_latency": start_time.elapsed().as_secs_f64(), + }); + + let meta_obj = meta_info.as_object_mut().unwrap(); + + // Add matched_stop if present + if let Some(matched) = complete.matched_stop.take() { + use proto::generate_complete::MatchedStop; + let matched_value = match matched { + MatchedStop::MatchedTokenId(id) => json!(id), + MatchedStop::MatchedStopStr(s) => json!(s), + }; + meta_obj.insert("matched_stop".to_string(), matched_value); + } + + let response_body = json!({ + "text": decoded_text, + "output_ids": output_ids, + "meta_info": meta_info, + }); + + Json(response_body).into_response() + } + /// Convert proto LogProbs to OpenAI ChatLogProbs format /// Note: Always decodes with skip_special_tokens=false to show actual tokens generated fn convert_proto_to_openai_logprobs( @@ -803,7 +1065,7 @@ impl GrpcRouter { complete: &proto::GenerateComplete, index: usize, original_request: &ChatCompletionRequest, - stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder, + stop_decoder: &mut StopSequenceDecoder, ) -> Result { stop_decoder.reset(); // Decode tokens @@ -1002,11 +1264,11 @@ impl RouterTrait for GrpcRouter { async fn route_generate( &self, - _headers: Option<&HeaderMap>, - _body: &GenerateRequest, - _model_id: Option<&str>, + headers: Option<&HeaderMap>, + body: &GenerateRequest, + model_id: Option<&str>, ) -> Response { - (StatusCode::NOT_IMPLEMENTED).into_response() + self.route_generate_impl(headers, body, model_id).await } async fn route_chat(