From 01c9ee1ab44fd732af38c947b69350dbfc24a194 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 8 Oct 2025 12:38:50 -0400 Subject: [PATCH] [router] refactor generate to use new pipeline arch (#11323) --- sgl-router/src/protocols/spec.rs | 91 ++-- sgl-router/src/routers/grpc/context.rs | 33 +- sgl-router/src/routers/grpc/pd_router.rs | 637 +---------------------- sgl-router/src/routers/grpc/pipeline.rs | 232 +++++++-- sgl-router/src/routers/grpc/router.rs | 434 +-------------- sgl-router/src/routers/grpc/streaming.rs | 401 ++++++++++++-- sgl-router/src/routers/grpc/utils.rs | 66 ++- 7 files changed, 713 insertions(+), 1181 deletions(-) diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index dc60f9c2b..7e12a8cfe 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -2066,39 +2066,64 @@ impl GenerationRequest for GenerateRequest { } } -// TODO(generate): Define GenerateResponse and GenerateChoice structs -// -// Required for pipeline generate response processing (see grpc/pipeline.rs:931-964) -// -// #[derive(Debug, Clone, Serialize, Deserialize)] -// pub struct GenerateResponse { -// pub id: String, -// pub object: String, // "text.completion" -// pub created: u64, -// pub model: String, -// pub choices: Vec, -// #[serde(skip_serializing_if = "Option::is_none")] -// pub usage: Option, -// #[serde(skip_serializing_if = "Option::is_none")] -// pub system_fingerprint: Option, -// } -// -// #[derive(Debug, Clone, Serialize, Deserialize)] -// pub struct GenerateChoice { -// pub index: u32, -// pub text: String, -// #[serde(skip_serializing_if = "Option::is_none")] -// pub output_ids: Option>, -// #[serde(skip_serializing_if = "Option::is_none")] -// pub finish_reason: Option, -// #[serde(skip_serializing_if = "Option::is_none")] -// pub logprobs: Option, -// #[serde(skip_serializing_if = "Option::is_none")] -// pub matched_stop: Option, -// } -// -// Note: Verify if similar structs already exist elsewhere before implementing. -// May need streaming variant (GenerateStreamResponse) as well. +// ============================================================================ +// SGLang Generate Response Types +// ============================================================================ + +/// SGLang generate response (single completion or array for n>1) +/// +/// Format for n=1: +/// ```json +/// { +/// "text": "...", +/// "output_ids": [...], +/// "meta_info": { ... } +/// } +/// ``` +/// +/// Format for n>1: +/// ```json +/// [ +/// {"text": "...", "output_ids": [...], "meta_info": {...}}, +/// {"text": "...", "output_ids": [...], "meta_info": {...}} +/// ] +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GenerateResponse { + pub text: String, + pub output_ids: Vec, + pub meta_info: GenerateMetaInfo, +} + +/// Metadata for a single generate completion +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GenerateMetaInfo { + pub id: String, + pub finish_reason: GenerateFinishReason, + pub prompt_tokens: u32, + pub weight_version: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_token_logprobs: Option>>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_token_logprobs: Option>>>, + pub completion_tokens: u32, + pub cached_tokens: u32, + pub e2e_latency: f64, + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, +} + +/// Finish reason for generate endpoint +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum GenerateFinishReason { + Length { + length: u32, + }, + Stop, + #[serde(untagged)] + Other(Value), +} // Constants for rerank API pub const DEFAULT_MODEL_NAME: &str = "default"; diff --git a/sgl-router/src/routers/grpc/context.rs b/sgl-router/src/routers/grpc/context.rs index f6bd73462..5c6bb7a99 100644 --- a/sgl-router/src/routers/grpc/context.rs +++ b/sgl-router/src/routers/grpc/context.rs @@ -12,7 +12,9 @@ use serde_json::Value; use crate::core::Worker; use crate::grpc_client::{proto, SglangSchedulerClient}; -use crate::protocols::spec::{ChatCompletionRequest, ChatCompletionResponse, GenerateRequest}; +use crate::protocols::spec::{ + ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, GenerateResponse, +}; use crate::reasoning_parser::ReasoningParserFactory; use crate::tokenizer::stop::StopSequenceDecoder; use crate::tokenizer::traits::Tokenizer; @@ -226,14 +228,6 @@ impl RequestContext { } } - /// Try to get chat request - pub fn try_chat_request(&self) -> Option<&ChatCompletionRequest> { - match &self.input.request_type { - RequestType::Chat(req) => Some(req.as_ref()), - _ => None, - } - } - /// Get generate request (panics if not generate) pub fn generate_request(&self) -> &GenerateRequest { match &self.input.request_type { @@ -242,14 +236,6 @@ impl RequestContext { } } - /// Try to get generate request - pub fn try_generate_request(&self) -> Option<&GenerateRequest> { - match &self.input.request_type { - RequestType::Generate(req) => Some(req.as_ref()), - _ => None, - } - } - /// Check if request is streaming pub fn is_streaming(&self) -> bool { match &self.input.request_type { @@ -257,16 +243,6 @@ impl RequestContext { RequestType::Generate(req) => req.stream, } } - - /// Check if request is chat - pub fn is_chat(&self) -> bool { - matches!(&self.input.request_type, RequestType::Chat(_)) - } - - /// Check if request is generate - pub fn is_generate(&self) -> bool { - matches!(&self.input.request_type, RequestType::Generate(_)) - } } // ============================================================================ @@ -394,5 +370,6 @@ pub enum ExecutionResult { /// Final processed response pub enum FinalResponse { Chat(ChatCompletionResponse), - Generate(Box), + /// Generate response is a Vec of GenerateResponse (n=1 returns single item, n>1 returns multiple) + Generate(Vec), } diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index daad4b9d8..a387c16b6 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -1,40 +1,27 @@ // PD (Prefill-Decode) gRPC Router Implementation use crate::config::types::RetryConfig; -use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; -use crate::grpc_client::proto; -use crate::grpc_client::SglangSchedulerClient; +use crate::core::{ConnectionMode, WorkerRegistry, WorkerType}; use crate::policies::PolicyRegistry; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds, - RerankRequest, ResponsesGetParams, ResponsesRequest, + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, + ResponsesGetParams, ResponsesRequest, }; use crate::reasoning_parser::ReasoningParserFactory; -use crate::routers::http::pd_types::generate_room_id; -use crate::routers::{grpc, RouterTrait}; +use crate::routers::RouterTrait; use crate::server::AppContext; use crate::tokenizer::traits::Tokenizer; -use crate::tokenizer::SequenceDecoderOutput; use crate::tool_parser::ToolParserFactory; use async_trait::async_trait; use axum::{ body::Body, extract::Request, - http::{header, HeaderMap, HeaderValue, StatusCode}, + http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, - Json, }; -use grpc::utils; -use proto::generate_response::Response::{Chunk, Complete, Error}; -use std::collections::HashMap; use std::sync::Arc; -use std::time::Instant; -use tokio::sync::mpsc::unbounded_channel; -use tokio::sync::mpsc::UnboundedSender; -use tokio_stream::Stream; -use tokio_stream::StreamExt; -use tracing::{debug, error}; -use uuid::Uuid; + +use tracing::debug; /// gRPC PD (Prefill-Decode) router implementation for SGLang #[derive(Clone)] @@ -50,9 +37,7 @@ pub struct GrpcPDRouter { retry_config: RetryConfig, configured_reasoning_parser: Option, configured_tool_parser: Option, - // Pipeline for non-streaming requests pipeline: super::pipeline::ChatCompletionPipeline, - // Shared components for pipeline shared_components: Arc, } @@ -129,93 +114,10 @@ impl GrpcPDRouter { }) } - /// Select a prefill-decode worker pair using load balancing policies - async fn select_pd_pair( - &self, - request_text: Option<&str>, - model_id: Option<&str>, - ) -> Result<(Arc, Arc), String> { - let effective_model_id = if !self.dp_aware { None } else { model_id }; - - debug!( - "Selecting PD pair: dp_aware={}, model_id={:?}, effective_model_id={:?}", - self.dp_aware, model_id, effective_model_id - ); - - // Get prefill workers - let prefill_workers = if let Some(model) = effective_model_id { - self.worker_registry - .get_by_model_fast(model) - .into_iter() - .filter(|w| matches!(w.worker_type(), WorkerType::Prefill { .. })) - .collect() - } else { - self.worker_registry.get_workers_filtered( - None, - Some(WorkerType::Prefill { - bootstrap_port: None, - }), - Some(ConnectionMode::Grpc { port: None }), - true, // only healthy workers - ) - }; - - // Get decode workers - let decode_workers = if let Some(model) = effective_model_id { - self.worker_registry - .get_by_model_fast(model) - .into_iter() - .filter(|w| matches!(w.worker_type(), WorkerType::Decode)) - .collect() - } else { - self.worker_registry.get_workers_filtered( - None, - Some(WorkerType::Decode), - Some(ConnectionMode::Grpc { port: None }), - true, // only healthy workers - ) - }; - - if prefill_workers.is_empty() { - return Err("No healthy prefill workers available".to_string()); - } - if decode_workers.is_empty() { - return Err("No healthy decode workers available".to_string()); - } - - debug!( - "Found {} prefill workers and {} decode workers", - prefill_workers.len(), - decode_workers.len() - ); - - let prefill_policy = self.policy_registry.get_prefill_policy(); - let decode_policy = self.policy_registry.get_decode_policy(); - - let prefill_idx = prefill_policy - .select_worker(&prefill_workers, request_text) - .ok_or_else(|| "Failed to select prefill worker".to_string())?; - - let decode_idx = decode_policy - .select_worker(&decode_workers, request_text) - .ok_or_else(|| "Failed to select decode worker".to_string())?; - - let prefill = prefill_workers[prefill_idx].clone(); - let decode = decode_workers[decode_idx].clone(); - - debug!( - "Selected PD pair: prefill={}, decode={}", - prefill.url(), - decode.url() - ); - - Ok((prefill, decode)) - } - /// Main route_generate implementation with PD dual dispatch async fn route_generate_impl( &self, - _headers: Option<&HeaderMap>, + headers: Option<&HeaderMap>, body: &GenerateRequest, model_id: Option<&str>, ) -> Response { @@ -224,125 +126,15 @@ impl GrpcPDRouter { model_id ); - // Step 1: Resolve input (text or input_ids) - let (original_text, token_ids) = match self.resolve_generate_input(body) { - Ok(res) => res, - Err(msg) => { - return utils::bad_request_error(msg); - } - }; - - debug!("Resolved input with {} tokens", token_ids.len()); - - // Step 2: Select prefill-decode worker pair - let (prefill_worker, decode_worker) = match self - .select_pd_pair(original_text.as_deref(), model_id) - .await - { - Ok(pair) => pair, - Err(e) => { - return utils::service_unavailable_error(e); - } - }; - - debug!( - "Selected PD pair: prefill={}, decode={}", - prefill_worker.url(), - decode_worker.url() - ); - - // Step 3: Get gRPC clients for both workers - let prefill_client = match utils::get_grpc_client_from_worker(&prefill_worker).await { - Ok(client) => client, - Err(response) => return response, - }; - - let decode_client = match utils::get_grpc_client_from_worker(&decode_worker).await { - Ok(client) => client, - Err(response) => return response, - }; - - // Step 4: Build the gRPC request - let request_id = body - .rid - .clone() - .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4())); - - let mut request = match prefill_client.build_plain_generate_request( - request_id.clone(), - body, - original_text.clone(), - token_ids, - ) { - Ok(req) => req, - Err(e) => { - return utils::bad_request_error(e); - } - }; - - // Step 5: Inject bootstrap metadata - if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) { - return utils::internal_error_message(e); - } - - // Step 6: Get weight version for response metadata - let weight_version = decode_worker - .metadata() - .labels - .get("weight_version") - .cloned() - .unwrap_or_else(|| "default".to_string()); - - // Step 7: Handle streaming vs non-streaming - if body.stream { - self.handle_streaming_generate( - prefill_client, - decode_client, - request, - body, - request_id, - weight_version, + // Use pipeline for ALL requests (streaming and non-streaming) + self.pipeline + .execute_generate( + body.clone(), + headers.cloned(), + model_id.map(|s| s.to_string()), + self.shared_components.clone(), ) .await - } else { - self.handle_non_streaming_generate( - prefill_client, - decode_client, - request, - body, - request_id, - weight_version, - ) - .await - } - } - - /// Inject bootstrap metadata into a protobuf GenerateRequest - fn inject_bootstrap_metadata( - request: &mut proto::GenerateRequest, - prefill_worker: &dyn Worker, - ) -> Result<(), String> { - let hostname = prefill_worker.bootstrap_host(); - let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998); - - let room_id = generate_room_id(); - - // Create DisaggregatedParams - let disagg_params = proto::DisaggregatedParams { - bootstrap_host: hostname.to_string(), - bootstrap_port: bootstrap_port as i32, - bootstrap_room: room_id as i32, - }; - - // Inject metadata - request.disaggregated_params = Some(disagg_params); - - debug!( - "Injected bootstrap metadata: host={}, port={}, room={}", - hostname, bootstrap_port, room_id - ); - - Ok(()) } /// Main route_chat implementation with PD dual dispatch @@ -367,405 +159,6 @@ impl GrpcPDRouter { ) .await } - - /// Resolve the generate input into optional original text and token IDs - fn resolve_generate_input( - &self, - request: &GenerateRequest, - ) -> Result<(Option, Vec), String> { - if let Some(text) = &request.text { - let encoding = self - .tokenizer - .encode(text) - .map_err(|e| format!("Tokenization failed: {}", e))?; - return Ok((Some(text.to_string()), encoding.token_ids().to_vec())); - } - - // Handle input_ids - validate and convert - if let Some(input_ids) = &request.input_ids { - return match input_ids { - InputIds::Single(ids) => ids - .iter() - .map(|&id| u32::try_from(id)) - .collect::, _>>() - .map(|converted| (None, converted)) - .map_err(|_| "input_ids must be non-negative".to_string()), - InputIds::Batch(_) => { - Err("Batch input_ids are not supported in PD mode".to_string()) - } - }; - } - - Err("Either `text` or `input_ids` must be provided".to_string()) - } - - /// Submit request and handle streaming response for generate endpoint (PD mode) - async fn handle_streaming_generate( - &self, - mut prefill_client: SglangSchedulerClient, - mut decode_client: SglangSchedulerClient, - request: proto::GenerateRequest, - original_request: &GenerateRequest, - request_id: String, - weight_version: String, - ) -> Response { - // Create channel for SSE streaming - let (tx, rx) = unbounded_channel::>(); - - // Send requests in parallel to both prefill and decode workers - debug!("Starting concurrent streaming generate requests to prefill and decode workers"); - let prefill_request = request.clone(); - let decode_request = request; - - let (prefill_result, decode_result) = tokio::join!( - prefill_client.generate(prefill_request), - decode_client.generate(decode_request) - ); - - // Get prefill stream (for input_logprobs if needed) - let prefill_stream = match prefill_result { - Ok(s) => s, - Err(e) => { - return utils::internal_error_message(format!( - "Prefill worker failed to start: {}", - e - )); - } - }; - - // Get decode stream - this is what we'll process for output - let decode_stream = match decode_result { - Ok(s) => s, - Err(e) => { - return utils::internal_error_message(format!( - "Decode worker failed to start: {}", - e - )); - } - }; - - // Spawn processing task for both streams - let tokenizer = self.tokenizer.clone(); - let return_logprob = original_request.return_logprob; - tokio::spawn(async move { - let result = Self::process_generate_streaming( - tokenizer, - prefill_stream, - decode_stream, - request_id, - weight_version, - return_logprob, - &tx, - ) - .await; - - if let Err(e) = result { - let error_chunk = format!( - "data: {}\n\n", - serde_json::json!({ - "error": { - "message": e, - "type": "internal_error" - } - }) - ); - let _ = tx.send(Ok(bytes::Bytes::from(error_chunk))); - } - - // Send DONE marker - let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n"))); - }); - - // Create response with SSE headers - let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - let mut response = Response::new(Body::from_stream(stream)); - *response.status_mut() = StatusCode::OK; - response.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("text/event-stream"), - ); - response - .headers_mut() - .insert("Cache-Control", HeaderValue::from_static("no-cache")); - response - .headers_mut() - .insert("Connection", HeaderValue::from_static("keep-alive")); - response - } - - /// Process generate streaming (simplified - no tool calls or reasoning) - #[allow(clippy::too_many_arguments)] - async fn process_generate_streaming( - tokenizer: Arc, - mut prefill_stream: impl Stream> + Unpin, - mut decode_stream: impl Stream> + Unpin, - request_id: String, - weight_version: String, - include_logprobs: bool, - tx: &UnboundedSender>, - ) -> Result<(), String> { - let start_time = Instant::now(); - - // Phase 1: Collect input_logprobs from prefill stream if requested - // TODO: Store and emit input_logprobs when implementing prompt logprobs in streaming - if include_logprobs { - while let Some(response) = prefill_stream.next().await { - let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?; - match gen_response.response { - Some(Complete(_complete)) => { - // Input logprobs collected but not yet used in streaming - break; - } - Some(Error(error)) => { - return Err(format!("Prefill error: {}", error.message)); - } - _ => continue, - } - } - } - - // Phase 2: Main streaming loop (decode stream) - // Track state per index for n>1 case - let mut accumulated_texts: HashMap = HashMap::new(); - let mut completion_tokens_map: HashMap = HashMap::new(); - let mut current_index: u32 = 0; - - while let Some(response) = decode_stream.next().await { - let gen_response = response.map_err(|e| format!("Decode stream error: {}", e))?; - - match gen_response.response { - Some(Chunk(chunk)) => { - // Use our tracked index instead of chunk.index (PD backend bug workaround) - let index = current_index; - debug!( - "Received chunk with backend_index={}, using_index={}, tokens={:?}", - chunk.index, index, chunk.token_ids - ); - - let completion_tokens = completion_tokens_map.entry(index).or_insert(0); - *completion_tokens += chunk.token_ids.len() as u32; - - let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default(); - - let accumulated_text = accumulated_texts.entry(index).or_default(); - accumulated_text.push_str(&chunk_text); - - let index_id = format!("{}-{}", request_id, index); - - let chunk_response = serde_json::json!({ - "text": accumulated_text.clone(), - "output_ids": chunk.token_ids, - "meta_info": { - "id": index_id, - "finish_reason": null, - "prompt_tokens": chunk.prompt_tokens, - "weight_version": weight_version, - "completion_tokens": *completion_tokens, - "cached_tokens": chunk.cached_tokens - }, - "index": index - }); - - let sse_chunk = format!( - "data: {}\n\n", - serde_json::to_string(&chunk_response).unwrap() - ); - tx.send(Ok(bytes::Bytes::from(sse_chunk))) - .map_err(|_| "Failed to send chunk".to_string())?; - } - Some(Complete(complete)) => { - let index = current_index; - debug!( - "Received Complete with backend_index={}, using_index={}, finish_reason={}", - complete.index, index, complete.finish_reason - ); - let accumulated_text = - accumulated_texts.get(&index).cloned().unwrap_or_default(); - let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0); - let index_id = format!("{}-{}", request_id, index); - let e2e_latency = start_time.elapsed().as_secs_f64(); - - // Send final chunk with finish_reason (no new tokens in Complete, they were already sent in Chunks) - let finish_response = serde_json::json!({ - "text": accumulated_text, - "output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(), - "meta_info": { - "id": index_id, - "finish_reason": complete.finish_reason, - "prompt_tokens": complete.prompt_tokens, - "weight_version": weight_version, - "completion_tokens": completion_tokens, - "cached_tokens": complete.cached_tokens, - "e2e_latency": e2e_latency - }, - "index": index - }); - - let sse_chunk = format!( - "data: {}\n\n", - serde_json::to_string(&finish_response).unwrap() - ); - tx.send(Ok(bytes::Bytes::from(sse_chunk))) - .map_err(|_| "Failed to send finish chunk".to_string())?; - - // Move to next completion - current_index += 1; - } - Some(Error(error)) => { - return Err(error.message); - } - None => continue, - } - } - - Ok(()) - } - - /// Submit request and handle non-streaming response for generate endpoint (PD mode) - async fn handle_non_streaming_generate( - &self, - mut prefill_client: SglangSchedulerClient, - mut decode_client: SglangSchedulerClient, - request: proto::GenerateRequest, - original_request: &GenerateRequest, - request_id: String, - weight_version: String, - ) -> Response { - use std::time::Instant; - - let start_time = Instant::now(); - - // Send requests in parallel - debug!("Sending concurrent generate requests to prefill and decode workers"); - let prefill_request = request.clone(); - let decode_request = request; - - let (prefill_result, decode_result) = tokio::join!( - prefill_client.generate(prefill_request), - decode_client.generate(decode_request) - ); - - // Process prefill stream - let prefill_stream = match prefill_result { - Ok(s) => s, - Err(e) => { - error!("Failed to start prefill generation: {}", e); - return utils::internal_error_message(format!( - "Prefill worker failed to start: {}", - e - )); - } - }; - - let decode_stream = match decode_result { - Ok(s) => s, - Err(e) => { - error!("Failed to start decode generation: {}", e); - return utils::internal_error_message(format!( - "Decode worker failed to start: {}", - e - )); - } - }; - - // Collect prefill responses - // TODO add logprob for generate - let _prefill_responses = - match utils::collect_stream_responses(prefill_stream, "Prefill").await { - Ok(responses) => responses, - Err(error_response) => return error_response, - }; - - // Collect decode responses - let decode_responses = match utils::collect_stream_responses(decode_stream, "Decode").await - { - Ok(responses) => responses, - Err(error_response) => return error_response, - }; - - if decode_responses.is_empty() { - return utils::internal_error_static("No completion received from decode worker"); - } - - // Create stop decoder from sampling params - let params = original_request.sampling_params.as_ref(); - let mut stop_decoder = utils::create_stop_decoder( - &self.tokenizer, - params.and_then(|p| p.stop.as_ref()), - params.and_then(|p| p.stop_token_ids.as_ref()), - params.and_then(|p| p.skip_special_tokens).unwrap_or(true), - params.and_then(|p| p.no_stop_trim).unwrap_or(false), - ); - - // Process each completion - let mut result_array = Vec::new(); - for mut complete in decode_responses { - stop_decoder.reset(); - - // Process tokens through stop decoder - let outputs = match stop_decoder.process_tokens(&complete.output_ids) { - Ok(outputs) => outputs, - Err(e) => { - return utils::internal_error_message(format!( - "Failed to process tokens: {}", - e - )) - } - }; - - // Accumulate text with early breaks - let mut decoded_text = String::new(); - for output in outputs { - match output { - SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t), - SequenceDecoderOutput::StoppedWithText(t) => { - decoded_text.push_str(&t); - break; - } - SequenceDecoderOutput::Stopped => break, - SequenceDecoderOutput::Held => {} - } - } - - // Flush remaining text - if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { - decoded_text.push_str(&t); - } - - let output_ids = complete.output_ids.clone(); - - // Build base meta_info - let mut meta_info = serde_json::json!({ - "id": request_id.clone(), - "finish_reason": complete.finish_reason.clone(), - "prompt_tokens": complete.prompt_tokens, - "weight_version": weight_version.clone(), - "completion_tokens": complete.completion_tokens, - "cached_tokens": complete.cached_tokens, - "e2e_latency": start_time.elapsed().as_secs_f64(), - }); - - let meta_obj = meta_info.as_object_mut().unwrap(); - - // Add matched_stop if present - if let Some(matched) = complete.matched_stop.take() { - use proto::generate_complete::MatchedStop; - let matched_value = match matched { - MatchedStop::MatchedTokenId(id) => serde_json::json!(id), - MatchedStop::MatchedStopStr(s) => serde_json::json!(s), - }; - meta_obj.insert("matched_stop".to_string(), matched_value); - } - - result_array.push(serde_json::json!({ - "text": decoded_text, - "output_ids": output_ids, - "meta_info": meta_info, - })); - } - - Json(result_array).into_response() - } } impl std::fmt::Debug for GrpcPDRouter { diff --git a/sgl-router/src/routers/grpc/pipeline.rs b/sgl-router/src/routers/grpc/pipeline.rs index 380be569e..01c97df63 100644 --- a/sgl-router/src/routers/grpc/pipeline.rs +++ b/sgl-router/src/routers/grpc/pipeline.rs @@ -11,15 +11,20 @@ use super::context::*; use super::processing; use super::streaming; use super::utils; -use crate::core::{ConnectionMode, WorkerRegistry, WorkerType}; +use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; use crate::grpc_client::proto; use crate::policies::PolicyRegistry; use crate::protocols::spec::{ - ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, InputIds, Usage, + ChatCompletionRequest, ChatCompletionResponse, GenerateMetaInfo, GenerateRequest, + GenerateResponse, InputIds, Usage, }; +use crate::tokenizer::stop::SequenceDecoderOutput; +use crate::tokenizer::traits::Tokenizer; +use proto::generate_complete::MatchedStop; +use proto::DisaggregatedParams; use rand::Rng; use std::sync::Arc; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; use uuid::Uuid; // ============================================================================ @@ -208,7 +213,7 @@ impl PreparationStage { fn tokenize_single_text( &self, - tokenizer: &Arc, + tokenizer: &Arc, text: &str, ) -> Result<(String, Vec), String> { let encoding = tokenizer @@ -302,7 +307,7 @@ impl WorkerSelectionStage { &self, model_id: Option<&str>, text: Option<&str>, - ) -> Option> { + ) -> Option> { // Get workers for the specified model, filtered by connection mode let workers = self.worker_registry.get_workers_filtered( model_id, @@ -312,7 +317,7 @@ impl WorkerSelectionStage { ); // Filter by availability (health + circuit breaker) - let available: Vec> = workers + let available: Vec> = workers .iter() .filter(|w| w.is_available()) .cloned() @@ -337,7 +342,7 @@ impl WorkerSelectionStage { &self, model_id: Option<&str>, text: Option<&str>, - ) -> Option<(Arc, Arc)> { + ) -> Option<(Arc, Arc)> { // Get prefill workers - use None for WorkerType filter to get all types, // then filter manually (since Prefill is a struct variant) let all_workers = self.worker_registry.get_workers_filtered( @@ -537,10 +542,8 @@ impl RequestBuildingStage { fn inject_bootstrap_metadata( &self, request: &mut proto::GenerateRequest, - prefill_worker: &Arc, + prefill_worker: &Arc, ) { - use proto::DisaggregatedParams; - let hostname = prefill_worker.bootstrap_host(); let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998); @@ -935,40 +938,183 @@ impl ResponseProcessingStage { async fn process_generate_response( &self, - _ctx: &mut RequestContext, + ctx: &mut RequestContext, ) -> Result, Response> { - // TODO(generate): Implement generate response processing - // - // Required implementation: - // 1. Extract execution_result from ctx - // 2. Check is_streaming flag - // 3. For streaming: - // - Add StreamingProcessor::process_streaming_generate() method - // - Similar to process_streaming_response but WITHOUT tool/reasoning parsing - // - Return Err(sse_response) for early exit - // 4. For non-streaming: - // - Collect stream responses using utils::collect_stream_responses() - // - Process through stop decoder (sequential with reset for n>1, like chat) - // - Build GenerateResponse struct (see TODO in protocols/spec.rs) - // - Set ctx.state.response.final_response = Some(FinalResponse::Generate(response)) - // - // Reference implementation: router.rs:297-595 - // Key differences from chat: - // - No tool parsing - // - No reasoning parsing - // - Different response format (GenerateResponse instead of ChatCompletionResponse) - // - Still needs: stop decoder, logprobs, finish_reason, matched_stop - Err(( - axum::http::StatusCode::NOT_IMPLEMENTED, - axum::Json(serde_json::json!({ - "error": { - "message": "Generate response processing not yet implemented in pipeline", - "type": "not_implemented", - "code": 501 + let start_time = Instant::now(); + let is_streaming = ctx.is_streaming(); + + // Extract execution result + let execution_result = ctx + .state + .response + .execution_result + .take() + .ok_or_else(|| utils::internal_error_static("No execution result"))?; + + if is_streaming { + // Get dispatch metadata for consistent response fields + let dispatch = ctx + .state + .dispatch + .as_ref() + .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?; + + let generate_request = ctx.generate_request().clone(); + + // Streaming: Use StreamingProcessor and return SSE response (done) + return Ok(Some( + self.streaming_processor.clone().process_streaming_generate( + execution_result, + generate_request, + dispatch.clone(), + ), + )); + } + + // Non-streaming: Collect all responses + let request_logprobs = ctx.generate_request().return_logprob; + let all_responses = match execution_result { + ExecutionResult::Single { stream } => { + utils::collect_stream_responses(stream, "Single").await? + } + ExecutionResult::Dual { prefill, decode } => { + // Collect prefill for input_logprobs + let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?; + + // Collect decode for actual output + let mut decode_responses = + utils::collect_stream_responses(*decode, "Decode").await?; + + // Merge prefill input_logprobs if requested + if request_logprobs { + if let Some(prefill_input_logprobs) = prefill_responses + .first() + .and_then(|r| r.input_logprobs.clone()) + { + for response in &mut decode_responses { + response.input_logprobs = Some(prefill_input_logprobs.clone()); + } + } } - })), - ) - .into_response()) + + decode_responses + } + }; + + if all_responses.is_empty() { + return Err(utils::internal_error_static("No responses from server")); + } + + // Get stop decoder for processing + let stop_decoder = ctx + .state + .response + .stop_decoder + .as_mut() + .ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?; + + // Get dispatch metadata + let dispatch = ctx + .state + .dispatch + .as_ref() + .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?; + + // Process each completion (similar to router.rs:336-400) + let mut result_array = Vec::new(); + for mut complete in all_responses { + stop_decoder.reset(); + + // Process tokens through stop decoder + let outputs = match stop_decoder.process_tokens(&complete.output_ids) { + Ok(outputs) => outputs, + Err(e) => { + return Err(utils::internal_error_message(format!( + "Failed to process tokens: {}", + e + ))) + } + }; + + // Accumulate text with early breaks + let mut decoded_text = String::new(); + for output in outputs { + match output { + SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t), + SequenceDecoderOutput::StoppedWithText(t) => { + decoded_text.push_str(&t); + break; + } + SequenceDecoderOutput::Stopped => break, + SequenceDecoderOutput::Held => {} + } + } + + // Flush remaining text + if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { + decoded_text.push_str(&t); + } + + let output_ids = std::mem::take(&mut complete.output_ids); + let finish_reason_str = std::mem::take(&mut complete.finish_reason); + + // Parse finish_reason from string to proper type + let finish_reason = + utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens); + + // Handle matched_stop if present + let matched_stop = complete.matched_stop.take().map(|matched| match matched { + MatchedStop::MatchedTokenId(id) => serde_json::json!(id), + MatchedStop::MatchedStopStr(s) => serde_json::json!(s), + }); + + // Extract logprobs if requested (convert proto types to Generate format) + let input_token_logprobs = if request_logprobs { + complete + .input_logprobs + .as_ref() + .map(utils::convert_generate_input_logprobs) + } else { + None + }; + + let output_token_logprobs = if request_logprobs { + complete + .output_logprobs + .as_ref() + .map(utils::convert_generate_output_logprobs) + } else { + None + }; + + // Build GenerateResponse struct + let meta_info = GenerateMetaInfo { + id: dispatch.request_id.clone(), + finish_reason, + prompt_tokens: complete.prompt_tokens as u32, + weight_version: dispatch + .weight_version + .clone() + .unwrap_or_else(|| "default".to_string()), + input_token_logprobs, + output_token_logprobs, + completion_tokens: complete.completion_tokens as u32, + cached_tokens: complete.cached_tokens as u32, + e2e_latency: start_time.elapsed().as_secs_f64(), + matched_stop, + }; + + result_array.push(GenerateResponse { + text: decoded_text, + output_ids, + meta_info, + }); + } + + // Store the final response + ctx.state.response.final_response = Some(FinalResponse::Generate(result_array)); + + Ok(None) } } @@ -1136,7 +1282,7 @@ impl ChatCompletionPipeline { // Extract final response match ctx.state.response.final_response { - Some(FinalResponse::Generate(response)) => axum::Json(*response).into_response(), + Some(FinalResponse::Generate(response)) => axum::Json(response).into_response(), Some(FinalResponse::Chat(_)) => { utils::internal_error_static("Internal error: wrong response type") } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 6ccd18da6..9e95e49f4 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -8,28 +8,21 @@ use axum::{ extract::Request, http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, - Json, }; use tracing::debug; use crate::config::types::RetryConfig; -use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; -use crate::grpc_client::{proto, SglangSchedulerClient}; +use crate::core::WorkerRegistry; use crate::policies::PolicyRegistry; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds, - RerankRequest, ResponsesGetParams, ResponsesRequest, + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, + ResponsesGetParams, ResponsesRequest, }; use crate::reasoning_parser::ReasoningParserFactory; -use crate::routers::{grpc, RouterTrait}; +use crate::routers::RouterTrait; use crate::server::AppContext; -use crate::tokenizer::stop::SequenceDecoderOutput; use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ToolParserFactory; -use grpc::utils; -use serde_json::json; -use std::time::Instant; -use uuid::Uuid; /// gRPC router implementation for SGLang #[derive(Clone)] @@ -45,9 +38,7 @@ pub struct GrpcRouter { retry_config: RetryConfig, configured_reasoning_parser: Option, configured_tool_parser: Option, - // Pipeline for non-streaming requests pipeline: super::pipeline::ChatCompletionPipeline, - // Shared components for pipeline shared_components: Arc, } @@ -149,420 +140,21 @@ impl GrpcRouter { /// Main route_generate implementation async fn route_generate_impl( &self, - _headers: Option<&HeaderMap>, + headers: Option<&HeaderMap>, body: &GenerateRequest, model_id: Option<&str>, ) -> Response { debug!("Processing generate request for model: {:?}", model_id); - // Step 1: Resolve input (text, prompt, or input_ids) - let (original_text, token_ids) = match self.resolve_generate_input(body) { - Ok(res) => res, - Err(msg) => { - return utils::bad_request_error(msg); - } - }; - - debug!("Resolved input with {} tokens", token_ids.len()); - - // Step 2: Select worker (fail fast if no workers available) - let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) { - Some(w) => w, - None => { - return utils::service_unavailable_error(format!( - "No available workers for model: {:?}", - model_id - )); - } - }; - - debug!("Selected worker: {}", worker.url()); - - // Step 3: Get gRPC client from worker - let client = match utils::get_grpc_client_from_worker(&worker).await { - Ok(client) => client, - Err(response) => return response, - }; - - // Step 4: Build the gRPC request - let request_id = body - .rid - .clone() - .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4())); - - let request = match client.build_plain_generate_request( - request_id.clone(), - body, - original_text.clone(), - token_ids, - ) { - Ok(req) => req, - Err(e) => { - return utils::bad_request_error(e); - } - }; - - // Step 5: Get weight version for response metadata - let weight_version = worker - .metadata() - .labels - .get("weight_version") - .cloned() - .unwrap_or_else(|| "default".to_string()); - - // Step 6: Handle streaming vs non-streaming - if body.stream { - self.handle_streaming_generate(client, request, body, request_id, weight_version) - .await - } else { - self.handle_non_streaming_generate(client, request, body, request_id, weight_version) - .await - } - } - - /// Select a worker for the request - fn select_worker_for_request( - &self, - model_id: Option<&str>, - text: Option<&str>, - ) -> Option> { - // Get workers for the specified model, filtered by connection mode - let workers = self.worker_registry.get_workers_filtered( - model_id, - Some(WorkerType::Regular), - Some(ConnectionMode::Grpc { port: None }), - false, // get all workers, we'll filter by is_available() next - ); - - // Filter by availability (health + circuit breaker) - let available: Vec> = workers - .iter() - .filter(|w| w.is_available()) - .cloned() - .collect(); - - if available.is_empty() { - return None; - } - - // Get the appropriate policy for this model - let policy = match model_id { - Some(model) => self.policy_registry.get_policy_or_default(model), - None => self.policy_registry.get_default_policy(), - }; - - // Select worker using the policy - let idx = policy.select_worker(&available, text)?; - Some(available[idx].clone()) - } - - /// Resolve the generate input into optional original text and token IDs - fn resolve_generate_input( - &self, - request: &GenerateRequest, - ) -> Result<(Option, Vec), String> { - if let Some(text) = &request.text { - return self - .tokenize_single_text(text) - .map(|(original, ids)| (Some(original), ids)); - } - - // Handle input_ids - validate and convert - if let Some(input_ids) = &request.input_ids { - return match input_ids { - InputIds::Single(ids) => ids - .iter() - .map(|&id| u32::try_from(id)) - .collect::, _>>() - .map(|converted| (None, converted)) - .map_err(|_| "input_ids must be non-negative".to_string()), - InputIds::Batch(_) => { - Err("Batch input_ids are not supported over gRPC generate yet".to_string()) - } - }; - } - - Err("Either `text` or `input_ids` must be provided".to_string()) - } - - fn tokenize_single_text(&self, text: &str) -> Result<(String, Vec), String> { - let encoding = self - .tokenizer - .encode(text) - .map_err(|e| format!("Tokenization failed: {}", e))?; - Ok((text.to_string(), encoding.token_ids().to_vec())) - } - - /// Submit request and handle non-streaming response for the `/generate` endpoint - async fn handle_non_streaming_generate( - &self, - mut client: SglangSchedulerClient, - request: proto::GenerateRequest, - original_request: &GenerateRequest, - request_id: String, - weight_version: String, - ) -> Response { - let start_time = Instant::now(); - - let stream = match client.generate(request).await { - Ok(stream) => stream, - Err(e) => { - return utils::internal_error_message(format!("Failed to start generation: {}", e)) - } - }; - - // Collect all responses using utils helper - let responses = match utils::collect_stream_responses(stream, "Generate").await { - Ok(responses) => responses, - Err(error_response) => return error_response, - }; - - if responses.is_empty() { - return utils::internal_error_static("No completion received from scheduler"); - } - - // Create stop decoder from sampling params - let params = original_request.sampling_params.as_ref(); - let mut stop_decoder = utils::create_stop_decoder( - &self.tokenizer, - params.and_then(|p| p.stop.as_ref()), - params.and_then(|p| p.stop_token_ids.as_ref()), - params.and_then(|p| p.skip_special_tokens).unwrap_or(true), - params.and_then(|p| p.no_stop_trim).unwrap_or(false), - ); - - // Process each completion - let mut result_array = Vec::new(); - for mut complete in responses { - stop_decoder.reset(); - - // Process tokens through stop decoder - let outputs = match stop_decoder.process_tokens(&complete.output_ids) { - Ok(outputs) => outputs, - Err(e) => { - return utils::internal_error_message(format!( - "Failed to process tokens: {}", - e - )) - } - }; - - // Accumulate text with early breaks - let mut decoded_text = String::new(); - for output in outputs { - match output { - SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t), - SequenceDecoderOutput::StoppedWithText(t) => { - decoded_text.push_str(&t); - break; - } - SequenceDecoderOutput::Stopped => break, - SequenceDecoderOutput::Held => {} - } - } - - // Flush remaining text - if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { - decoded_text.push_str(&t); - } - - let output_ids = std::mem::take(&mut complete.output_ids); - let finish_reason = std::mem::take(&mut complete.finish_reason); - - // Build base meta_info using json! macro - let mut meta_info = json!({ - "id": request_id.clone(), - "finish_reason": finish_reason, - "prompt_tokens": complete.prompt_tokens, - "weight_version": weight_version.clone(), - "completion_tokens": complete.completion_tokens, - "cached_tokens": complete.cached_tokens, - "e2e_latency": start_time.elapsed().as_secs_f64(), - }); - - let meta_obj = meta_info.as_object_mut().unwrap(); - - // Add matched_stop if present - if let Some(matched) = complete.matched_stop.take() { - use proto::generate_complete::MatchedStop; - let matched_value = match matched { - MatchedStop::MatchedTokenId(id) => json!(id), - MatchedStop::MatchedStopStr(s) => json!(s), - }; - meta_obj.insert("matched_stop".to_string(), matched_value); - } - - result_array.push(json!({ - "text": decoded_text, - "output_ids": output_ids, - "meta_info": meta_info, - })); - } - - Json(result_array).into_response() - } - - /// Submit request and handle streaming response for the `/generate` endpoint - async fn handle_streaming_generate( - &self, - mut client: SglangSchedulerClient, - request: proto::GenerateRequest, - original_request: &GenerateRequest, - request_id: String, - weight_version: String, - ) -> Response { - let tokenizer = self.tokenizer.clone(); - let return_logprob = original_request.return_logprob; - - // Create channel for SSE streaming - let (tx, rx) = - tokio::sync::mpsc::unbounded_channel::>(); - - // Start the stream - let stream = match client.generate(request).await { - Ok(stream) => stream, - Err(e) => { - return utils::internal_error_message(format!("Failed to start generation: {}", e)) - } - }; - - // Spawn async task to process stream - tokio::spawn(async move { - let result = Self::process_generate_streaming( - tokenizer, - stream, - request_id, - weight_version, - return_logprob, - &tx, + // Use pipeline for ALL requests (streaming and non-streaming) + self.pipeline + .execute_generate( + body.clone(), + headers.cloned(), + model_id.map(|s| s.to_string()), + self.shared_components.clone(), ) - .await; - - if let Err(e) = result { - let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e); - let _ = tx.send(Ok(bytes::Bytes::from(error_chunk))); - } - - // Send [DONE] marker - let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n"))); - }); - - // Create SSE response stream - let body_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "text/event-stream") - .header("Cache-Control", "no-cache") - .header("Connection", "keep-alive") - .body(axum::body::Body::from_stream(body_stream)) - .unwrap() - } - - /// Process streaming chunks for generate endpoint - async fn process_generate_streaming( - tokenizer: Arc, - mut stream: impl tokio_stream::Stream> - + Unpin, - request_id: String, - weight_version: String, - _include_logprobs: bool, - tx: &tokio::sync::mpsc::UnboundedSender>, - ) -> Result<(), String> { - use proto::generate_response::Response::{Chunk, Complete, Error}; - use std::time::Instant; - use tokio_stream::StreamExt; - - let start_time = Instant::now(); - - // Track state per index for n>1 case - use std::collections::HashMap; - let mut accumulated_texts: HashMap = HashMap::new(); - let mut completion_tokens_map: HashMap = HashMap::new(); - - while let Some(response) = stream.next().await { - let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; - - match gen_response.response { - Some(Chunk(chunk)) => { - let index = chunk.index; - - // Update completion tokens for this index - let completion_tokens = completion_tokens_map.entry(index).or_insert(0); - *completion_tokens += chunk.token_ids.len() as u32; - - // Decode tokens to text (skip_special_tokens=true to handle newlines correctly) - let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default(); - - // Accumulate text for this index - let accumulated_text = accumulated_texts.entry(index).or_default(); - accumulated_text.push_str(&chunk_text); - - // Generate unique ID per index - let index_id = format!("{}-{}", request_id, index); - - // Build streaming response chunk (SGLang format) - let chunk_response = serde_json::json!({ - "text": accumulated_text.clone(), - "output_ids": chunk.token_ids, - "meta_info": { - "id": index_id, - "finish_reason": null, - "prompt_tokens": chunk.prompt_tokens, - "weight_version": weight_version, - "completion_tokens": *completion_tokens, - "cached_tokens": chunk.cached_tokens - }, - "index": index - }); - - let sse_chunk = format!( - "data: {}\n\n", - serde_json::to_string(&chunk_response).unwrap() - ); - tx.send(Ok(bytes::Bytes::from(sse_chunk))) - .map_err(|_| "Failed to send chunk".to_string())?; - } - Some(Complete(complete)) => { - let index = complete.index; - let accumulated_text = - accumulated_texts.get(&index).cloned().unwrap_or_default(); - let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0); - let index_id = format!("{}-{}", request_id, index); - let e2e_latency = start_time.elapsed().as_secs_f64(); - - // Send final chunk with finish_reason (no new tokens in Complete, they were already sent in Chunks) - let finish_response = serde_json::json!({ - "text": accumulated_text, - "output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(), - "meta_info": { - "id": index_id, - "finish_reason": complete.finish_reason, - "prompt_tokens": complete.prompt_tokens, - "weight_version": weight_version, - "completion_tokens": completion_tokens, - "cached_tokens": complete.cached_tokens, - "e2e_latency": e2e_latency - }, - "index": index - }); - - let sse_chunk = format!( - "data: {}\n\n", - serde_json::to_string(&finish_response).unwrap() - ); - tx.send(Ok(bytes::Bytes::from(sse_chunk))) - .map_err(|_| "Failed to send finish chunk".to_string())?; - - // Continue to process all completions if n>1 - } - Some(Error(error)) => { - return Err(error.message); - } - None => continue, - } - } - - Ok(()) + .await } } diff --git a/sgl-router/src/routers/grpc/streaming.rs b/sgl-router/src/routers/grpc/streaming.rs index 0337ce365..fc9a8a68b 100644 --- a/sgl-router/src/routers/grpc/streaming.rs +++ b/sgl-router/src/routers/grpc/streaming.rs @@ -17,15 +17,18 @@ use tokio_stream::StreamExt; use tonic::codec::Streaming; use tracing::{debug, error, warn}; +use super::context; +use super::utils; use crate::grpc_client::proto; use crate::protocols::spec::*; use crate::reasoning_parser::ReasoningParser; use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ToolParser; - -use super::context; -use super::utils; +use proto::generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId}; +use proto::generate_response::Response::{Chunk, Complete, Error}; +use std::time::Instant; +use tokio::sync::mpsc; /// Shared streaming processor for both single and dual dispatch modes #[derive(Clone)] @@ -65,7 +68,7 @@ impl StreamingProcessor { execution_result: context::ExecutionResult, chat_request: ChatCompletionRequest, dispatch: context::DispatchMetadata, - ) -> axum::response::Response { + ) -> Response { use bytes::Bytes; use tokio::sync::mpsc; @@ -194,7 +197,7 @@ impl StreamingProcessor { let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; match gen_response.response { - Some(proto::generate_response::Response::Chunk(chunk)) => { + Some(Chunk(chunk)) => { let index = chunk.index; // Get or create stop decoder for this index @@ -336,7 +339,7 @@ impl StreamingProcessor { .map_err(|_| "Failed to send content chunk".to_string())?; } } - Some(proto::generate_response::Response::Complete(complete)) => { + Some(Complete(complete)) => { let index = complete.index; // Flush any remaining text for this index's stop_decoder @@ -385,19 +388,17 @@ impl StreamingProcessor { // Extract matched_stop let matched_stop_value = match &complete.matched_stop { - Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { + Some(MatchedTokenId(token_id)) => { Some(Value::Number(serde_json::Number::from(*token_id))) } - Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { - Some(Value::String(stop_str.clone())) - } + Some(MatchedStopStr(stop_str)) => Some(Value::String(stop_str.clone())), None => None, }; matched_stops.insert(index, matched_stop_value); // Don't break - continue reading all Complete messages for n>1 } - Some(proto::generate_response::Response::Error(error)) => { + Some(Error(error)) => { return Err(error.message); } None => continue, @@ -536,12 +537,12 @@ impl StreamingProcessor { while let Some(response) = prefill_stream.next().await { let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?; match gen_response.response { - Some(proto::generate_response::Response::Complete(_complete)) => { + Some(Complete(_complete)) => { // Input logprobs collected but not yet used in streaming // (OpenAI spec doesn't require prompt logprobs in streaming responses) break; } - Some(proto::generate_response::Response::Error(error)) => { + Some(Error(error)) => { return Err(format!("Prefill error: {}", error.message)); } _ => continue, @@ -554,23 +555,359 @@ impl StreamingProcessor { .await } - // TODO(generate): Add streaming generate handler - // - // pub async fn process_streaming_generate( - // self: Arc, - // execution_result: context::ExecutionResult, - // generate_request: GenerateRequest, - // dispatch: context::DispatchMetadata, - // ) -> axum::response::Response { - // // Similar to process_streaming_response but: - // // - No tool parsing - // // - No reasoning parsing - // // - Simpler chunk format (just text + finish_reason + logprobs) - // // - Extract stop params from generate_request.sampling_params - // // - Use same per-index stop decoder logic - // // - Emit SSE chunks with format similar to chat but without delta.tool_calls - // // Reference: router.rs:422-595 - // } + /// Process streaming generate response and return SSE response + /// + /// Simpler than chat - no tool/reasoning parsing, just text accumulation + pub fn process_streaming_generate( + self: Arc, + execution_result: context::ExecutionResult, + generate_request: GenerateRequest, + dispatch: context::DispatchMetadata, + ) -> Response { + let return_logprob = generate_request.return_logprob; + + // Create SSE channel + let (tx, rx) = mpsc::unbounded_channel::>(); + + // Spawn background task based on execution mode + match execution_result { + context::ExecutionResult::Single { stream } => { + let tokenizer = self.tokenizer.clone(); + let request_id = dispatch.request_id.clone(); + let weight_version = dispatch + .weight_version + .clone() + .unwrap_or_else(|| "default".to_string()); + tokio::spawn(async move { + let result = Self::process_generate_streaming( + tokenizer, + stream, + request_id, + weight_version, + return_logprob, + &tx, + ) + .await; + + if let Err(e) = result { + let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e); + let _ = tx.send(Ok(Bytes::from(error_chunk))); + } + + let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); + }); + } + context::ExecutionResult::Dual { prefill, decode } => { + // For PD mode, need to handle prefill stream for input_logprobs + let tokenizer = self.tokenizer.clone(); + let request_id = dispatch.request_id.clone(); + let weight_version = dispatch + .weight_version + .clone() + .unwrap_or_else(|| "default".to_string()); + tokio::spawn(async move { + let result = Self::process_generate_streaming_dual( + tokenizer, + prefill, + *decode, + request_id, + weight_version, + return_logprob, + &tx, + ) + .await; + + if let Err(e) = result { + let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e); + let _ = tx.send(Ok(Bytes::from(error_chunk))); + } + + let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); + }); + } + } + + // Return SSE response + build_sse_response(rx) + } + + //TODO add streaming logprob support + /// Process streaming chunks for generate endpoint (no tool/reasoning parsing) + async fn process_generate_streaming( + tokenizer: Arc, + mut stream: Streaming, + request_id: String, + weight_version: String, + _include_logprobs: bool, + tx: &UnboundedSender>, + ) -> Result<(), String> { + let start_time = Instant::now(); + + // Track state per index for n>1 case + let mut accumulated_texts: HashMap = HashMap::new(); + let mut completion_tokens_map: HashMap = HashMap::new(); + + while let Some(response) = stream.next().await { + let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; + + match gen_response.response { + Some(Chunk(chunk)) => { + let index = chunk.index; + + // Update completion tokens for this index + let completion_tokens = completion_tokens_map.entry(index).or_insert(0); + *completion_tokens += chunk.token_ids.len() as u32; + + // Decode tokens to text (skip_special_tokens=true to handle newlines correctly) + let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default(); + + // Accumulate text for this index + let accumulated_text = accumulated_texts.entry(index).or_default(); + accumulated_text.push_str(&chunk_text); + + // Generate unique ID per index + let index_id = format!("{}-{}", request_id, index); + + // Build streaming response chunk (SGLang format) + let chunk_response = serde_json::json!({ + "text": accumulated_text.clone(), + "output_ids": chunk.token_ids, + "meta_info": { + "id": index_id, + "finish_reason": null, + "prompt_tokens": chunk.prompt_tokens, + "weight_version": &weight_version, + "completion_tokens": *completion_tokens, + "cached_tokens": chunk.cached_tokens + }, + "index": index + }); + + let sse_chunk = format!( + "data: {}\n\n", + serde_json::to_string(&chunk_response).unwrap() + ); + tx.send(Ok(Bytes::from(sse_chunk))) + .map_err(|_| "Failed to send chunk".to_string())?; + } + Some(Complete(complete)) => { + let index = complete.index; + let accumulated_text = + accumulated_texts.get(&index).cloned().unwrap_or_default(); + let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0); + let index_id = format!("{}-{}", request_id, index); + let e2e_latency = start_time.elapsed().as_secs_f64(); + + // Send final chunk with finish_reason + let finish_response = serde_json::json!({ + "text": accumulated_text, + "output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(), + "meta_info": { + "id": index_id, + "finish_reason": complete.finish_reason, + "prompt_tokens": complete.prompt_tokens, + "weight_version": &weight_version, + "completion_tokens": completion_tokens, + "cached_tokens": complete.cached_tokens, + "e2e_latency": e2e_latency + }, + "index": index + }); + + let sse_chunk = format!( + "data: {}\n\n", + serde_json::to_string(&finish_response).unwrap() + ); + tx.send(Ok(Bytes::from(sse_chunk))) + .map_err(|_| "Failed to send finish chunk".to_string())?; + + // Continue to process all completions if n>1 + } + Some(Error(error)) => { + return Err(error.message); + } + None => continue, + } + } + + Ok(()) + } + + /// Process dual streaming for generate endpoint (PD mode with logprobs support) + async fn process_generate_streaming_dual( + tokenizer: Arc, + mut prefill_stream: Streaming, + decode_stream: Streaming, + request_id: String, + weight_version: String, + return_logprob: bool, + tx: &UnboundedSender>, + ) -> Result<(), String> { + // Collect input_logprobs from prefill stream if requested + let input_token_logprobs = if return_logprob { + let mut input_logprobs = None; + while let Some(response) = prefill_stream.next().await { + let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?; + match gen_response.response { + Some(Complete(complete)) => { + // Extract input_logprobs from prefill Complete message (convert proto to SGLang format) + input_logprobs = complete + .input_logprobs + .as_ref() + .map(utils::convert_generate_input_logprobs); + break; + } + Some(Error(error)) => { + return Err(format!("Prefill error: {}", error.message)); + } + _ => continue, + } + } + input_logprobs + } else { + None + }; + + // Process decode stream with input_logprobs prepended + Self::process_generate_streaming_with_input_logprobs( + tokenizer, + decode_stream, + request_id, + weight_version, + return_logprob, + input_token_logprobs, + tx, + ) + .await + } + + /// Process generate streaming with optional input_logprobs + async fn process_generate_streaming_with_input_logprobs( + tokenizer: Arc, + mut stream: Streaming, + request_id: String, + weight_version: String, + _include_logprobs: bool, + input_token_logprobs: Option>>>, + tx: &UnboundedSender>, + ) -> Result<(), String> { + let start_time = Instant::now(); + + // Track state per index for n>1 case + let mut accumulated_texts: HashMap = HashMap::new(); + let mut accumulated_output_logprobs: HashMap>>>> = + HashMap::new(); + let mut completion_tokens_map: HashMap = HashMap::new(); + + while let Some(response) = stream.next().await { + let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; + + match gen_response.response { + Some(Chunk(chunk)) => { + let index = chunk.index; + + // Update completion tokens for this index + let completion_tokens = completion_tokens_map.entry(index).or_insert(0); + *completion_tokens += chunk.token_ids.len() as u32; + + // Decode tokens to text + let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default(); + + // Accumulate text for this index + let accumulated_text = accumulated_texts.entry(index).or_default(); + accumulated_text.push_str(&chunk_text); + + // Store latest output logprobs (cumulative from proto, convert to SGLang format) + if let Some(ref output_logprobs) = chunk.output_logprobs { + let converted = + super::utils::convert_generate_output_logprobs(output_logprobs); + accumulated_output_logprobs.insert(index, Some(converted)); + } + + // Generate unique ID per index + let index_id = format!("{}-{}", request_id, index); + + // Build streaming response chunk with cumulative logprobs + let current_output_logprobs = accumulated_output_logprobs + .get(&index) + .and_then(|o| o.as_ref()); + + let chunk_response = serde_json::json!({ + "text": accumulated_text.clone(), + "output_ids": chunk.token_ids, + "meta_info": { + "id": index_id, + "finish_reason": null, + "prompt_tokens": chunk.prompt_tokens, + "weight_version": &weight_version, + "input_token_logprobs": input_token_logprobs.as_ref(), + "output_token_logprobs": current_output_logprobs, + "completion_tokens": *completion_tokens, + "cached_tokens": chunk.cached_tokens + }, + "index": index + }); + + let sse_chunk = format!( + "data: {}\n\n", + serde_json::to_string(&chunk_response).unwrap() + ); + tx.send(Ok(Bytes::from(sse_chunk))) + .map_err(|_| "Failed to send chunk".to_string())?; + } + Some(Complete(complete)) => { + let index = complete.index; + let accumulated_text = + accumulated_texts.get(&index).cloned().unwrap_or_default(); + let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0); + let final_output_logprobs = accumulated_output_logprobs + .get(&index) + .and_then(|o| o.as_ref()); + let index_id = format!("{}-{}", request_id, index); + let e2e_latency = start_time.elapsed().as_secs_f64(); + + // Parse finish_reason + let finish_reason = utils::parse_finish_reason( + &complete.finish_reason, + complete.completion_tokens, + ); + + // Send final chunk with finish_reason + let finish_response = serde_json::json!({ + "text": accumulated_text, + "output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(), + "meta_info": { + "id": index_id, + "finish_reason": finish_reason, + "prompt_tokens": complete.prompt_tokens, + "weight_version": &weight_version, + "input_token_logprobs": input_token_logprobs.as_ref(), + "output_token_logprobs": final_output_logprobs, + "completion_tokens": completion_tokens, + "cached_tokens": complete.cached_tokens, + "e2e_latency": e2e_latency + }, + "index": index + }); + + let sse_chunk = format!( + "data: {}\n\n", + serde_json::to_string(&finish_response).unwrap() + ); + tx.send(Ok(Bytes::from(sse_chunk))) + .map_err(|_| "Failed to send finish chunk".to_string())?; + + // Continue to process all completions if n>1 + } + Some(Error(error)) => { + return Err(error.message); + } + None => continue, + } + } + + Ok(()) + } // ======================================================================== // Helper Methods @@ -842,9 +1179,7 @@ impl StreamingProcessor { } /// Build SSE response with proper headers -pub fn build_sse_response( - rx: tokio::sync::mpsc::UnboundedReceiver>, -) -> Response { +pub fn build_sse_response(rx: mpsc::UnboundedReceiver>) -> Response { let stream = UnboundedReceiverStream::new(rx); let mut response = Response::new(Body::from_stream(stream)); *response.status_mut() = StatusCode::OK; diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs index c03f2bc8d..cc05cb32d 100644 --- a/sgl-router/src/routers/grpc/utils.rs +++ b/sgl-router/src/routers/grpc/utils.rs @@ -5,7 +5,7 @@ use crate::core::Worker; use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::protocols::spec::{ ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse, - StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb, + GenerateFinishReason, StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb, }; use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; use crate::tokenizer::traits::Tokenizer; @@ -809,6 +809,70 @@ pub fn convert_proto_to_openai_logprobs( }) } +/// Convert proto::OutputLogProbs to Generate format Vec>> +/// +/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...] +/// Each inner vec contains [logprob (f64), token_id (i32), ...] +pub fn convert_generate_output_logprobs( + proto_logprobs: &proto::OutputLogProbs, +) -> Vec>> { + proto_logprobs + .token_logprobs + .iter() + .zip(proto_logprobs.token_ids.iter()) + .map(|(&logprob, &token_id)| vec![Some(logprob as f64), Some(token_id as f64)]) + .collect() +} + +/// Convert proto::InputLogProbs to Generate format Vec>> +/// +/// Generate format: [[logprob, token_id, ...], [logprob, token_id, ...], ...] +/// First token has null logprob: [[null, token_id], [logprob, token_id], ...] +pub fn convert_generate_input_logprobs( + proto_logprobs: &proto::InputLogProbs, +) -> Vec>> { + proto_logprobs + .token_logprobs + .iter() + .zip(proto_logprobs.token_ids.iter()) + .map(|(token_logprob, &token_id)| { + // InputTokenLogProb has optional value field + let logprob_value = token_logprob.value.map(|v| v as f64); + vec![logprob_value, Some(token_id as f64)] + }) + .collect() +} + +/// Parse finish_reason string into GenerateFinishReason enum +/// +/// Uses serde to deserialize the finish_reason, which handles all tagged variants automatically. +/// The GenerateFinishReason enum is tagged with `#[serde(tag = "type", rename_all = "lowercase")]`, +/// so it expects JSON objects like: +/// - `{"type":"stop"}` -> Stop +/// - `{"type":"length","length":100}` -> Length { length: 100 } +/// - Any other JSON -> Other(...) +/// +/// For backward compatibility, also handles simple string "stop" -> Stop +pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> GenerateFinishReason { + if reason_str == "stop" { + return GenerateFinishReason::Stop; + } + + if reason_str == "length" { + return GenerateFinishReason::Length { + length: completion_tokens.max(0) as u32, + }; + } + + match serde_json::from_str::(reason_str) { + Ok(finish_reason) => finish_reason, + Err(_) => match serde_json::from_str::(reason_str) { + Ok(json_value) => GenerateFinishReason::Other(json_value), + Err(_) => GenerateFinishReason::Other(Value::String(reason_str.to_string())), + }, + } +} + #[cfg(test)] mod tests { use super::*;