From 0b9915c1329bec33fd1606679fc8cf57ee44190c Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Tue, 14 Oct 2025 02:51:33 -0400 Subject: [PATCH] [router] update generate spec to align with sgl io struct (#11591) --- sgl-router/benches/request_processing.rs | 36 +++- .../src/grpc_client/sglang_scheduler.rs | 10 +- sgl-router/src/protocols/spec.rs | 156 +++++++++++++++--- sgl-router/src/routers/grpc/pipeline.rs | 2 +- sgl-router/src/routers/grpc/streaming.rs | 2 +- sgl-router/src/routers/http/pd_router.rs | 17 +- sgl-router/tests/test_openai_routing.rs | 32 +++- 7 files changed, 196 insertions(+), 59 deletions(-) diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index 703ca55fd..b2d6d7430 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -28,15 +28,38 @@ fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option) { fn default_generate_request() -> GenerateRequest { GenerateRequest { text: None, - prompt: None, input_ids: None, - stream: false, + input_embeds: None, + image_data: None, + video_data: None, + audio_data: None, sampling_params: None, - return_logprob: false, - // SGLang Extensions - lora_path: None, - session_params: None, + return_logprob: None, + logprob_start_len: None, + top_logprobs_num: None, + token_ids_logprob: None, + return_text_in_logprobs: false, + stream: false, + log_metrics: true, return_hidden_states: false, + modalities: None, + session_params: None, + lora_path: None, + lora_id: None, + custom_logit_processor: None, + bootstrap_host: None, + bootstrap_port: None, + bootstrap_room: None, + bootstrap_pair_key: None, + data_parallel_rank: None, + background: false, + conversation_id: None, + priority: None, + extra_key: None, + no_logs: false, + custom_labels: None, + return_bytes: false, + return_entropy: false, rid: None, } } @@ -101,6 +124,7 @@ fn create_sample_generate_request() -> GenerateRequest { GenerateRequest { text: Some("Write a story about artificial intelligence".to_string()), sampling_params: Some(SamplingParams { + max_new_tokens: Some(100), temperature: Some(0.8), top_p: Some(0.9), top_k: Some(50), diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 3ff74d303..9a5ef9a1f 100644 --- a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -280,13 +280,13 @@ impl SglangSchedulerClient { input_ids: token_ids, }), sampling_params: Some(sampling_params), - return_logprob: body.return_logprob, - logprob_start_len: -1, - top_logprobs_num: 0, - token_ids_logprob: vec![], + return_logprob: body.return_logprob.unwrap_or(false), + logprob_start_len: body.logprob_start_len.unwrap_or(-1), + top_logprobs_num: body.top_logprobs_num.unwrap_or(0), + token_ids_logprob: body.token_ids_logprob.clone().unwrap_or_default(), return_hidden_states: body.return_hidden_states, stream: body.stream, - log_metrics: true, + log_metrics: body.log_metrics, ..Default::default() }; diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index c7eb42c98..9c6dd0651 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -356,7 +356,7 @@ pub struct ChatCompletionRequest { /// Path to LoRA adapter(s) for model customization #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, + pub lora_path: Option, /// Session parameters for continual prompting #[serde(skip_serializing_if = "Option::is_none")] @@ -905,7 +905,7 @@ pub struct CompletionRequest { /// Path to LoRA adapter(s) for model customization #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, + pub lora_path: Option, /// Session parameters for continual prompting #[serde(skip_serializing_if = "Option::is_none")] @@ -2309,10 +2309,6 @@ fn validate_sampling_params(params: &SamplingParams) -> Result<(), validator::Va #[derive(Clone, Debug, Serialize, Deserialize, Validate)] #[validate(schema(function = "validate_generate_request"))] pub struct GenerateRequest { - /// The prompt to generate from (OpenAI style) - #[serde(skip_serializing_if = "Option::is_none")] - pub prompt: Option, - /// Text input - SGLang native format #[serde(skip_serializing_if = "Option::is_none")] pub text: Option, @@ -2321,31 +2317,144 @@ pub struct GenerateRequest { #[serde(skip_serializing_if = "Option::is_none")] pub input_ids: Option, + /// Input embeddings for direct embedding input + /// Can be a 2D array (single request) or 3D array (batch of requests) + /// Placeholder for future use + #[serde(skip_serializing_if = "Option::is_none")] + pub input_embeds: Option, + + /// Image input data + /// Can be an image instance, file name, URL, or base64 encoded string + /// Supports single images, lists of images, or nested lists for batch processing + /// Placeholder for future use + #[serde(skip_serializing_if = "Option::is_none")] + pub image_data: Option, + + /// Video input data + /// Can be a file name, URL, or base64 encoded string + /// Supports single videos, lists of videos, or nested lists for batch processing + /// Placeholder for future use + #[serde(skip_serializing_if = "Option::is_none")] + pub video_data: Option, + + /// Audio input data + /// Can be a file name, URL, or base64 encoded string + /// Supports single audio files, lists of audio, or nested lists for batch processing + /// Placeholder for future use + #[serde(skip_serializing_if = "Option::is_none")] + pub audio_data: Option, + /// Sampling parameters (sglang style) #[serde(skip_serializing_if = "Option::is_none")] pub sampling_params: Option, + /// Whether to return logprobs + #[serde(skip_serializing_if = "Option::is_none")] + pub return_logprob: Option, + + /// If return logprobs, the start location in the prompt for returning logprobs. + #[serde(skip_serializing_if = "Option::is_none")] + pub logprob_start_len: Option, + + /// If return logprobs, the number of top logprobs to return at each position. + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs_num: Option, + + /// If return logprobs, the token ids to return logprob for. + #[serde(skip_serializing_if = "Option::is_none")] + pub token_ids_logprob: Option>, + + /// Whether to detokenize tokens in text in the returned logprobs. + #[serde(default)] + pub return_text_in_logprobs: bool, + /// Whether to stream the response #[serde(default)] pub stream: bool, - /// Whether to return logprobs - #[serde(default)] - pub return_logprob: bool, - - /// Path to LoRA adapter(s) for model customization - #[serde(skip_serializing_if = "Option::is_none")] - pub lora_path: Option, - - /// Session parameters for continual prompting - #[serde(skip_serializing_if = "Option::is_none")] - pub session_params: Option>, + /// Whether to log metrics for this request (e.g. health_generate calls do not log metrics) + #[serde(default = "default_true")] + pub log_metrics: bool, /// Return model hidden states #[serde(default)] pub return_hidden_states: bool, - /// Request ID for tracking + /// The modalities of the image data [image, multi-images, video] + #[serde(skip_serializing_if = "Option::is_none")] + pub modalities: Option>, + + /// Session parameters for continual prompting + #[serde(skip_serializing_if = "Option::is_none")] + pub session_params: Option>, + + /// Path to LoRA adapter(s) for model customization + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_path: Option, + + /// LoRA adapter ID (if pre-loaded) + #[serde(skip_serializing_if = "Option::is_none")] + pub lora_id: Option, + + /// Custom logit processor for advanced sampling control. Must be a serialized instance + /// of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + /// Use the processor's `to_str()` method to generate the serialized string. + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_logit_processor: Option, + + /// For disaggregated inference + #[serde(skip_serializing_if = "Option::is_none")] + pub bootstrap_host: Option, + + /// For disaggregated inference + #[serde(skip_serializing_if = "Option::is_none")] + pub bootstrap_port: Option, + + /// For disaggregated inference + #[serde(skip_serializing_if = "Option::is_none")] + pub bootstrap_room: Option, + + /// For disaggregated inference + #[serde(skip_serializing_if = "Option::is_none")] + pub bootstrap_pair_key: Option, + + /// Data parallel rank routing + #[serde(skip_serializing_if = "Option::is_none")] + pub data_parallel_rank: Option, + + /// Background response + #[serde(default)] + pub background: bool, + + /// Conversation ID for tracking + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation_id: Option, + + /// Priority for the request + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + + /// Extra key for classifying the request (e.g. cache_salt) + #[serde(skip_serializing_if = "Option::is_none")] + pub extra_key: Option, + + /// Whether to disallow logging for this request (e.g. due to ZDR) + #[serde(default)] + pub no_logs: bool, + + /// Custom metric labels + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_labels: Option>, + + /// Whether to return bytes for image generation + #[serde(default)] + pub return_bytes: bool, + + /// Whether to return entropy + #[serde(default)] + pub return_entropy: bool, + + /// Request ID for tracking (inherited from BaseReq in Python) #[serde(skip_serializing_if = "Option::is_none")] pub rid: Option, } @@ -2358,7 +2467,7 @@ impl Normalizable for GenerateRequest { fn validate_generate_request(req: &GenerateRequest) -> Result<(), validator::ValidationError> { // Exactly one of text or input_ids must be provided // Note: input_embeds not yet supported in Rust implementation - let has_text = req.text.is_some() || req.prompt.is_some(); + let has_text = req.text.is_some(); let has_input_ids = req.input_ids.is_some(); let count = [has_text, has_input_ids].iter().filter(|&&x| x).count(); @@ -2389,18 +2498,11 @@ impl GenerationRequest for GenerateRequest { } fn extract_text_for_routing(&self) -> String { - // Check fields in priority order: text, prompt, inputs + // Check fields in priority order: text, input_ids if let Some(ref text) = self.text { return text.clone(); } - if let Some(ref prompt) = self.prompt { - return match prompt { - StringOrArray::String(s) => s.clone(), - StringOrArray::Array(v) => v.join(" "), - }; - } - if let Some(ref input_ids) = self.input_ids { return match input_ids { InputIds::Single(ids) => ids diff --git a/sgl-router/src/routers/grpc/pipeline.rs b/sgl-router/src/routers/grpc/pipeline.rs index 94f55789d..ec46757aa 100644 --- a/sgl-router/src/routers/grpc/pipeline.rs +++ b/sgl-router/src/routers/grpc/pipeline.rs @@ -877,7 +877,7 @@ impl ResponseProcessingStage { } // Non-streaming: Delegate to ResponseProcessor - let request_logprobs = ctx.generate_request().return_logprob; + let request_logprobs = ctx.generate_request().return_logprob.unwrap_or(false); let generate_request = ctx.generate_request_arc(); let stop_decoder = ctx diff --git a/sgl-router/src/routers/grpc/streaming.rs b/sgl-router/src/routers/grpc/streaming.rs index c53304095..d27d95f98 100644 --- a/sgl-router/src/routers/grpc/streaming.rs +++ b/sgl-router/src/routers/grpc/streaming.rs @@ -616,7 +616,7 @@ impl StreamingProcessor { generate_request: Arc, dispatch: context::DispatchMetadata, ) -> Response { - let return_logprob = generate_request.return_logprob; + let return_logprob = generate_request.return_logprob.unwrap_or(false); // Create SSE channel let (tx, rx) = mpsc::unbounded_channel::>(); diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 77feb22ae..5a1a6fc2c 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -150,11 +150,6 @@ impl PDRouter { } fn get_generate_batch_size(req: &GenerateRequest) -> Option { - if let Some(StringOrArray::Array(arr)) = &req.prompt { - if !arr.is_empty() { - return Some(arr.len()); - } - } if let Some(text) = &req.text { if text.contains("[") && text.contains("]") { return None; @@ -1061,18 +1056,10 @@ impl RouterTrait for PDRouter { model_id: Option<&str>, ) -> Response { let is_stream = body.stream; - let return_logprob = body.return_logprob; + let return_logprob = body.return_logprob.unwrap_or(false); let request_text = if self.policies_need_request_text() { - body.text - .as_deref() - .or_else(|| { - body.prompt.as_ref().and_then(|p| match p { - StringOrArray::String(s) => Some(s.as_str()), - StringOrArray::Array(v) => v.first().map(|s| s.as_str()), - }) - }) - .map(|s| s.to_string()) + body.text.as_deref().map(|s| s.to_string()) } else { None }; diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index e864b4ec0..b38b90f7e 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -598,15 +598,39 @@ async fn test_unsupported_endpoints() { .unwrap(); let generate_request = GenerateRequest { - prompt: None, text: Some("Hello world".to_string()), input_ids: None, + input_embeds: None, + image_data: None, + video_data: None, + audio_data: None, sampling_params: None, stream: false, - return_logprob: false, - lora_path: None, - session_params: None, + return_logprob: Some(false), + logprob_start_len: None, + top_logprobs_num: None, + token_ids_logprob: None, + return_text_in_logprobs: false, + log_metrics: true, return_hidden_states: false, + modalities: None, + session_params: None, + lora_path: None, + lora_id: None, + custom_logit_processor: None, + bootstrap_host: None, + bootstrap_port: None, + bootstrap_room: None, + bootstrap_pair_key: None, + data_parallel_rank: None, + background: false, + conversation_id: None, + priority: None, + extra_key: None, + no_logs: false, + custom_labels: None, + return_bytes: false, + return_entropy: false, rid: None, };