diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index de6f79a2d..ade564e51 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -23,6 +23,9 @@ use std::sync::Arc; use tracing::debug; +use super::context::SharedComponents; +use super::pipeline::RequestPipeline; + /// gRPC PD (Prefill-Decode) router implementation for SGLang #[derive(Clone)] #[allow(dead_code)] // Fields will be used once implementation is complete @@ -37,8 +40,8 @@ pub struct GrpcPDRouter { retry_config: RetryConfig, configured_reasoning_parser: Option, configured_tool_parser: Option, - pipeline: super::pipeline::ChatCompletionPipeline, - shared_components: Arc, + pipeline: RequestPipeline, + shared_components: Arc, } impl GrpcPDRouter { @@ -66,36 +69,21 @@ impl GrpcPDRouter { .clone(); // Create shared components for pipeline - let shared_components = Arc::new(super::context::SharedComponents { + let shared_components = Arc::new(SharedComponents { tokenizer: tokenizer.clone(), tool_parser_factory: tool_parser_factory.clone(), reasoning_parser_factory: reasoning_parser_factory.clone(), }); - // Create response processor - let processor = super::processing::ResponseProcessor::new( - tokenizer.clone(), - tool_parser_factory.clone(), - reasoning_parser_factory.clone(), - ctx.configured_tool_parser.clone(), - ctx.configured_reasoning_parser.clone(), - ); - - // Create streaming processor - let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new( - tokenizer.clone(), - tool_parser_factory.clone(), - reasoning_parser_factory.clone(), - ctx.configured_tool_parser.clone(), - ctx.configured_reasoning_parser.clone(), - )); - // Create PD pipeline - let pipeline = super::pipeline::ChatCompletionPipeline::new_pd( + let pipeline = RequestPipeline::new_pd( worker_registry.clone(), policy_registry.clone(), - processor, - streaming_processor, + tokenizer.clone(), + tool_parser_factory.clone(), + reasoning_parser_factory.clone(), + ctx.configured_tool_parser.clone(), + ctx.configured_reasoning_parser.clone(), ); Ok(GrpcPDRouter { diff --git a/sgl-router/src/routers/grpc/pipeline.rs b/sgl-router/src/routers/grpc/pipeline.rs index 3be782ebb..1c23bf64a 100644 --- a/sgl-router/src/routers/grpc/pipeline.rs +++ b/sgl-router/src/routers/grpc/pipeline.rs @@ -14,13 +14,10 @@ use super::utils; use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; use crate::grpc_client::proto; use crate::policies::PolicyRegistry; -use crate::protocols::spec::{ - ChatCompletionRequest, ChatCompletionResponse, GenerateMetaInfo, GenerateRequest, - GenerateResponse, InputIds, Usage, -}; -use crate::tokenizer::stop::SequenceDecoderOutput; +use crate::protocols::spec::{ChatCompletionRequest, GenerateRequest, InputIds}; +use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::tokenizer::traits::Tokenizer; -use proto::generate_complete::MatchedStop; +use crate::tool_parser::ParserFactory as ToolParserFactory; use proto::DisaggregatedParams; use rand::Rng; use std::sync::Arc; @@ -790,114 +787,32 @@ impl ResponseProcessingStage { .take() .ok_or_else(|| utils::internal_error_static("No execution result"))?; - if is_streaming { - // Get dispatch metadata for consistent response fields - let dispatch = ctx - .state - .dispatch - .as_ref() - .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?; + // Get dispatch metadata (needed by both streaming and non-streaming) + let dispatch = ctx + .state + .dispatch + .as_ref() + .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))? + .clone(); + if is_streaming { // Streaming: Use StreamingProcessor and return SSE response (done) return Ok(Some( self.streaming_processor.clone().process_streaming_response( execution_result, ctx.chat_request_arc(), // Cheap Arc clone (8 bytes) - dispatch.clone(), + dispatch, ), )); } - // Non-streaming: Extract chat request details before mutable borrows + // Non-streaming: Delegate to ResponseProcessor let request_logprobs = match &ctx.input.request_type { RequestType::Chat(req) => req.logprobs, _ => false, }; - // Collect all responses from the execution result - let all_responses = match execution_result { - ExecutionResult::Single { mut stream } => { - let responses = utils::collect_stream_responses(&mut stream, "Single").await?; - stream.mark_completed(); - responses - } - ExecutionResult::Dual { - mut prefill, - decode, - } => { - // Collect prefill for input_logprobs (don't mark completed yet) - let prefill_responses = - utils::collect_stream_responses(&mut prefill, "Prefill").await?; - - // Collect decode for actual output (don't mark completed yet) - let mut decode_stream = *decode; - let mut decode_responses = - utils::collect_stream_responses(&mut decode_stream, "Decode").await?; - - // Mark both streams as completed now that both succeeded - prefill.mark_completed(); - decode_stream.mark_completed(); - - // Merge prefill input_logprobs if requested - if request_logprobs { - if let Some(prefill_input_logprobs) = prefill_responses - .first() - .and_then(|r| r.input_logprobs.clone()) - { - for response in &mut decode_responses { - response.input_logprobs = Some(prefill_input_logprobs.clone()); - } - } - } - - decode_responses - } - }; - - if all_responses.is_empty() { - return Err(utils::internal_error_static("No responses from server")); - } - let chat_request = ctx.chat_request_arc(); - let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request); - - // Check parser availability once upfront (not per choice) - let reasoning_parser_available = chat_request.separate_reasoning - && utils::check_reasoning_parser_availability( - &self.processor.reasoning_parser_factory, - self.processor.configured_reasoning_parser.as_ref(), - &chat_request.model, - ); - - let tool_choice_enabled = !matches!( - &chat_request.tool_choice, - Some(crate::protocols::spec::ToolChoice::Value( - crate::protocols::spec::ToolChoiceValue::None - )) - ); - - let tool_parser_available = tool_choice_enabled - && chat_request.tools.is_some() - && utils::check_tool_parser_availability( - &self.processor.tool_parser_factory, - self.processor.configured_tool_parser.as_ref(), - &chat_request.model, - ); - - // Log once per request (not per choice) - if chat_request.separate_reasoning && !reasoning_parser_available { - debug!( - "No reasoning parser found for model '{}', skipping reasoning parsing", - chat_request.model - ); - } - - if chat_request.tools.is_some() && tool_choice_enabled && !tool_parser_available { - debug!( - "No tool parser found for model '{}', skipping tool call parsing", - chat_request.model - ); - } let stop_decoder = ctx .state @@ -906,60 +821,16 @@ impl ResponseProcessingStage { .as_mut() .ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?; - let mut choices = Vec::new(); - for (index, complete) in all_responses.iter().enumerate() { - match self - .processor - .process_single_choice( - complete, - index, - &chat_request, - stop_decoder, - history_tool_calls_count, - reasoning_parser_available, - tool_parser_available, - ) - .await - { - Ok(choice) => choices.push(choice), - Err(e) => { - return Err(utils::internal_error_message(format!( - "Failed to process choice {}: {}", - index, e - ))); - } - } - } - - // Build usage - let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum(); - let total_completion_tokens: u32 = all_responses - .iter() - .map(|r| r.completion_tokens as u32) - .sum(); - let usage = Usage { - prompt_tokens: total_prompt_tokens, - completion_tokens: total_completion_tokens, - total_tokens: total_prompt_tokens + total_completion_tokens, - completion_tokens_details: None, - }; - - // Build final ChatCompletionResponse - let dispatch = ctx - .state - .dispatch - .as_ref() - .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?; - - let response = ChatCompletionResponse { - id: dispatch.request_id.clone(), - object: "chat.completion".to_string(), - created: dispatch.created, - model: dispatch.model.clone(), - choices, - usage: Some(usage), - system_fingerprint: dispatch.weight_version.clone(), - }; + let response = self + .processor + .process_non_streaming_chat_response( + execution_result, + chat_request, + dispatch, + stop_decoder, + request_logprobs, + ) + .await?; // Store the final response ctx.state.response.final_response = Some(FinalResponse::Chat(response)); @@ -982,70 +853,29 @@ impl ResponseProcessingStage { .take() .ok_or_else(|| utils::internal_error_static("No execution result"))?; - if is_streaming { - // Get dispatch metadata for consistent response fields - let dispatch = ctx - .state - .dispatch - .as_ref() - .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?; + // Get dispatch metadata (needed by both streaming and non-streaming) + let dispatch = ctx + .state + .dispatch + .as_ref() + .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))? + .clone(); + if is_streaming { // Streaming: Use StreamingProcessor and return SSE response (done) return Ok(Some( self.streaming_processor.clone().process_streaming_generate( execution_result, ctx.generate_request_arc(), // Cheap Arc clone (8 bytes) - dispatch.clone(), + dispatch, ), )); } - // Non-streaming: Collect all responses + // Non-streaming: Delegate to ResponseProcessor let request_logprobs = ctx.generate_request().return_logprob; - let all_responses = match execution_result { - ExecutionResult::Single { mut stream } => { - let responses = utils::collect_stream_responses(&mut stream, "Single").await?; - stream.mark_completed(); - responses - } - ExecutionResult::Dual { - mut prefill, - decode, - } => { - // Collect prefill for input_logprobs (don't mark completed yet) - let prefill_responses = - utils::collect_stream_responses(&mut prefill, "Prefill").await?; + let generate_request = ctx.generate_request_arc(); - // Collect decode for actual output (don't mark completed yet) - let mut decode_stream = *decode; - let mut decode_responses = - utils::collect_stream_responses(&mut decode_stream, "Decode").await?; - - // Mark both streams as completed now that both succeeded - prefill.mark_completed(); - decode_stream.mark_completed(); - - // Merge prefill input_logprobs if requested - if request_logprobs { - if let Some(prefill_input_logprobs) = prefill_responses - .first() - .and_then(|r| r.input_logprobs.clone()) - { - for response in &mut decode_responses { - response.input_logprobs = Some(prefill_input_logprobs.clone()); - } - } - } - - decode_responses - } - }; - - if all_responses.is_empty() { - return Err(utils::internal_error_static("No responses from server")); - } - - // Get stop decoder for processing let stop_decoder = ctx .state .response @@ -1053,103 +883,17 @@ impl ResponseProcessingStage { .as_mut() .ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?; - // Get dispatch metadata - let dispatch = ctx - .state - .dispatch - .as_ref() - .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?; - - // Process each completion (similar to router.rs:336-400) - let mut result_array = Vec::new(); - for mut complete in all_responses { - stop_decoder.reset(); - - // Process tokens through stop decoder - let outputs = match stop_decoder.process_tokens(&complete.output_ids) { - Ok(outputs) => outputs, - Err(e) => { - return Err(utils::internal_error_message(format!( - "Failed to process tokens: {}", - e - ))) - } - }; - - // Accumulate text with early breaks - let mut decoded_text = String::new(); - for output in outputs { - match output { - SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t), - SequenceDecoderOutput::StoppedWithText(t) => { - decoded_text.push_str(&t); - break; - } - SequenceDecoderOutput::Stopped => break, - SequenceDecoderOutput::Held => {} - } - } - - // Flush remaining text - if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { - decoded_text.push_str(&t); - } - - let output_ids = std::mem::take(&mut complete.output_ids); - let finish_reason_str = std::mem::take(&mut complete.finish_reason); - - // Parse finish_reason from string to proper type - let finish_reason = - utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens); - - // Handle matched_stop if present - let matched_stop = complete.matched_stop.take().map(|matched| match matched { - MatchedStop::MatchedTokenId(id) => serde_json::json!(id), - MatchedStop::MatchedStopStr(s) => serde_json::json!(s), - }); - - // Extract logprobs if requested (convert proto types to Generate format) - let input_token_logprobs = if request_logprobs { - complete - .input_logprobs - .as_ref() - .map(utils::convert_generate_input_logprobs) - } else { - None - }; - - let output_token_logprobs = if request_logprobs { - complete - .output_logprobs - .as_ref() - .map(utils::convert_generate_output_logprobs) - } else { - None - }; - - // Build GenerateResponse struct - let meta_info = GenerateMetaInfo { - id: dispatch.request_id.clone(), - finish_reason, - prompt_tokens: complete.prompt_tokens as u32, - weight_version: dispatch - .weight_version - .clone() - .unwrap_or_else(|| "default".to_string()), - input_token_logprobs, - output_token_logprobs, - completion_tokens: complete.completion_tokens as u32, - cached_tokens: complete.cached_tokens as u32, - e2e_latency: start_time.elapsed().as_secs_f64(), - matched_stop, - }; - - result_array.push(GenerateResponse { - text: decoded_text, - output_ids, - meta_info, - }); - } + let result_array = self + .processor + .process_non_streaming_generate_response( + execution_result, + generate_request, + dispatch, + stop_decoder, + request_logprobs, + start_time, + ) + .await?; // Store the final response ctx.state.response.final_response = Some(FinalResponse::Generate(result_array)); @@ -1162,23 +906,44 @@ impl ResponseProcessingStage { // Pipeline Orchestrator // ============================================================================ -/// Complete chat completion pipeline +/// Generic request pipeline for all request types /// /// Orchestrates all stages from request preparation to response delivery. /// Configured differently for regular vs PD mode. #[derive(Clone)] -pub struct ChatCompletionPipeline { +pub struct RequestPipeline { stages: Arc>>, } -impl ChatCompletionPipeline { +impl RequestPipeline { /// Create a regular (single-worker) pipeline pub fn new_regular( worker_registry: Arc, policy_registry: Arc, - processor: processing::ResponseProcessor, - streaming_processor: Arc, + tokenizer: Arc, + tool_parser_factory: ToolParserFactory, + reasoning_parser_factory: ReasoningParserFactory, + configured_tool_parser: Option, + configured_reasoning_parser: Option, ) -> Self { + // Create response processor + let processor = processing::ResponseProcessor::new( + tokenizer.clone(), + tool_parser_factory.clone(), + reasoning_parser_factory.clone(), + configured_tool_parser.clone(), + configured_reasoning_parser.clone(), + ); + + // Create streaming processor + let streaming_processor = Arc::new(streaming::StreamingProcessor::new( + tokenizer, + tool_parser_factory, + reasoning_parser_factory, + configured_tool_parser, + configured_reasoning_parser, + )); + let stages: Vec> = vec![ Box::new(PreparationStage), Box::new(WorkerSelectionStage::new( @@ -1190,10 +955,7 @@ impl ChatCompletionPipeline { Box::new(RequestBuildingStage::new(false)), // No PD metadata Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::Single)), - Box::new(ResponseProcessingStage::new( - processor, - streaming_processor.clone(), - )), + Box::new(ResponseProcessingStage::new(processor, streaming_processor)), ]; Self { @@ -1205,9 +967,30 @@ impl ChatCompletionPipeline { pub fn new_pd( worker_registry: Arc, policy_registry: Arc, - processor: processing::ResponseProcessor, - streaming_processor: Arc, + tokenizer: Arc, + tool_parser_factory: ToolParserFactory, + reasoning_parser_factory: ReasoningParserFactory, + configured_tool_parser: Option, + configured_reasoning_parser: Option, ) -> Self { + // Create response processor + let processor = processing::ResponseProcessor::new( + tokenizer.clone(), + tool_parser_factory.clone(), + reasoning_parser_factory.clone(), + configured_tool_parser.clone(), + configured_reasoning_parser.clone(), + ); + + // Create streaming processor + let streaming_processor = Arc::new(streaming::StreamingProcessor::new( + tokenizer, + tool_parser_factory, + reasoning_parser_factory, + configured_tool_parser, + configured_reasoning_parser, + )); + let stages: Vec> = vec![ Box::new(PreparationStage), Box::new(WorkerSelectionStage::new( @@ -1219,10 +1002,7 @@ impl ChatCompletionPipeline { Box::new(RequestBuildingStage::new(true)), // Inject PD metadata Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), - Box::new(ResponseProcessingStage::new( - processor, - streaming_processor.clone(), - )), + Box::new(ResponseProcessingStage::new(processor, streaming_processor)), ]; Self { diff --git a/sgl-router/src/routers/grpc/processing.rs b/sgl-router/src/routers/grpc/processing.rs index 50718ea2c..294b7d6af 100644 --- a/sgl-router/src/routers/grpc/processing.rs +++ b/sgl-router/src/routers/grpc/processing.rs @@ -10,14 +10,18 @@ use tracing::error; use crate::grpc_client::proto; use crate::protocols::spec::{ - ChatChoice, ChatCompletionMessage, ChatCompletionRequest, FunctionCallResponse, ToolCall, - ToolChoice, ToolChoiceValue, + ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, + FunctionCallResponse, GenerateMetaInfo, GenerateRequest, GenerateResponse, ToolCall, + ToolChoice, ToolChoiceValue, Usage, }; use crate::reasoning_parser::ParserFactory as ReasoningParserFactory; use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ParserFactory as ToolParserFactory; +use proto::generate_complete::MatchedStop; +use std::time::Instant; +use super::context::{DispatchMetadata, ExecutionResult}; use super::utils; // ============================================================================ @@ -51,6 +55,57 @@ impl ResponseProcessor { } } + /// Helper to collect responses from execution result and merge logprobs if needed + async fn collect_and_merge_responses( + execution_result: ExecutionResult, + request_logprobs: bool, + ) -> Result, axum::response::Response> { + let all_responses = match execution_result { + ExecutionResult::Single { mut stream } => { + let responses = utils::collect_stream_responses(&mut stream, "Single").await?; + stream.mark_completed(); + responses + } + ExecutionResult::Dual { + mut prefill, + decode, + } => { + // Collect prefill for input_logprobs (don't mark completed yet) + let prefill_responses = + utils::collect_stream_responses(&mut prefill, "Prefill").await?; + + // Collect decode for actual output (don't mark completed yet) + let mut decode_stream = *decode; + let mut decode_responses = + utils::collect_stream_responses(&mut decode_stream, "Decode").await?; + + // Mark both streams as completed now that both succeeded + prefill.mark_completed(); + decode_stream.mark_completed(); + + // Merge prefill input_logprobs if requested + if request_logprobs { + if let Some(prefill_input_logprobs) = prefill_responses + .first() + .and_then(|r| r.input_logprobs.clone()) + { + for response in &mut decode_responses { + response.input_logprobs = Some(prefill_input_logprobs.clone()); + } + } + } + + decode_responses + } + }; + + if all_responses.is_empty() { + return Err(utils::internal_error_static("No responses from server")); + } + + Ok(all_responses) + } + /// Process a single choice from GenerateComplete response (EXACT COPY from router.rs:1573-1725) #[allow(clippy::too_many_arguments)] pub async fn process_single_choice( @@ -158,12 +213,10 @@ impl ResponseProcessor { // Extract matched_stop information from proto let matched_stop = match &complete.matched_stop { - Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { + Some(MatchedStop::MatchedTokenId(token_id)) => { Some(Value::Number(serde_json::Number::from(*token_id))) } - Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { - Some(Value::String(stop_str.clone())) - } + Some(MatchedStop::MatchedStopStr(stop_str)) => Some(Value::String(stop_str.clone())), None => None, }; @@ -205,6 +258,109 @@ impl ResponseProcessor { Ok(choice) } + /// Process non-streaming chat response (collects all responses and builds final response) + pub async fn process_non_streaming_chat_response( + &self, + execution_result: ExecutionResult, + chat_request: Arc, + dispatch: DispatchMetadata, + stop_decoder: &mut StopSequenceDecoder, + request_logprobs: bool, + ) -> Result { + // Collect all responses from the execution result + let all_responses = + Self::collect_and_merge_responses(execution_result, request_logprobs).await?; + + let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request); + + // Check parser availability once upfront (not per choice) + let reasoning_parser_available = chat_request.separate_reasoning + && utils::check_reasoning_parser_availability( + &self.reasoning_parser_factory, + self.configured_reasoning_parser.as_ref(), + &chat_request.model, + ); + + let tool_choice_enabled = !matches!( + &chat_request.tool_choice, + Some(ToolChoice::Value(ToolChoiceValue::None)) + ); + + let tool_parser_available = tool_choice_enabled + && chat_request.tools.is_some() + && utils::check_tool_parser_availability( + &self.tool_parser_factory, + self.configured_tool_parser.as_ref(), + &chat_request.model, + ); + + // Log once per request (not per choice) + if chat_request.separate_reasoning && !reasoning_parser_available { + tracing::debug!( + "No reasoning parser found for model '{}', skipping reasoning parsing", + chat_request.model + ); + } + + if chat_request.tools.is_some() && tool_choice_enabled && !tool_parser_available { + tracing::debug!( + "No tool parser found for model '{}', skipping tool call parsing", + chat_request.model + ); + } + + // Process all choices + let mut choices = Vec::new(); + for (index, complete) in all_responses.iter().enumerate() { + match self + .process_single_choice( + complete, + index, + &chat_request, + stop_decoder, + history_tool_calls_count, + reasoning_parser_available, + tool_parser_available, + ) + .await + { + Ok(choice) => choices.push(choice), + Err(e) => { + return Err(utils::internal_error_message(format!( + "Failed to process choice {}: {}", + index, e + ))); + } + } + } + + // Build usage + let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum(); + let total_completion_tokens: u32 = all_responses + .iter() + .map(|r| r.completion_tokens as u32) + .sum(); + let usage = Usage { + prompt_tokens: total_prompt_tokens, + completion_tokens: total_completion_tokens, + total_tokens: total_prompt_tokens + total_completion_tokens, + completion_tokens_details: None, + }; + + // Build final ChatCompletionResponse + let response = ChatCompletionResponse { + id: dispatch.request_id.clone(), + object: "chat.completion".to_string(), + created: dispatch.created, + model: dispatch.model.clone(), + choices, + usage: Some(usage), + system_fingerprint: dispatch.weight_version.clone(), + }; + + Ok(response) + } + /// Parse tool calls using model-specific parser (EXACT COPY from router.rs:296-361) pub async fn parse_tool_calls( &self, @@ -264,4 +420,112 @@ impl ResponseProcessor { } } } + + /// Process non-streaming generate response (collects all responses and builds final response array) + pub async fn process_non_streaming_generate_response( + &self, + execution_result: ExecutionResult, + _generate_request: Arc, + dispatch: DispatchMetadata, + stop_decoder: &mut StopSequenceDecoder, + request_logprobs: bool, + start_time: Instant, + ) -> Result, axum::response::Response> { + // Collect all responses from the execution result + let all_responses = + Self::collect_and_merge_responses(execution_result, request_logprobs).await?; + + // Process each completion + let mut result_array = Vec::new(); + for mut complete in all_responses { + stop_decoder.reset(); + + // Process tokens through stop decoder + let outputs = match stop_decoder.process_tokens(&complete.output_ids) { + Ok(outputs) => outputs, + Err(e) => { + return Err(utils::internal_error_message(format!( + "Failed to process tokens: {}", + e + ))) + } + }; + + // Accumulate text with early breaks + let mut decoded_text = String::new(); + for output in outputs { + match output { + SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t), + SequenceDecoderOutput::StoppedWithText(t) => { + decoded_text.push_str(&t); + break; + } + SequenceDecoderOutput::Stopped => break, + SequenceDecoderOutput::Held => {} + } + } + + // Flush remaining text + if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { + decoded_text.push_str(&t); + } + + let output_ids = std::mem::take(&mut complete.output_ids); + let finish_reason_str = std::mem::take(&mut complete.finish_reason); + + // Parse finish_reason from string to proper type + let finish_reason = + utils::parse_finish_reason(&finish_reason_str, complete.completion_tokens); + + // Handle matched_stop if present + let matched_stop = complete.matched_stop.take().map(|matched| match matched { + MatchedStop::MatchedTokenId(id) => serde_json::json!(id), + MatchedStop::MatchedStopStr(s) => serde_json::json!(s), + }); + + // Extract logprobs if requested (convert proto types to Generate format) + let input_token_logprobs = if request_logprobs { + complete + .input_logprobs + .as_ref() + .map(utils::convert_generate_input_logprobs) + } else { + None + }; + + let output_token_logprobs = if request_logprobs { + complete + .output_logprobs + .as_ref() + .map(utils::convert_generate_output_logprobs) + } else { + None + }; + + // Build GenerateResponse struct + let meta_info = GenerateMetaInfo { + id: dispatch.request_id.clone(), + finish_reason, + prompt_tokens: complete.prompt_tokens as u32, + weight_version: dispatch + .weight_version + .clone() + .unwrap_or_else(|| "default".to_string()), + input_token_logprobs, + output_token_logprobs, + completion_tokens: complete.completion_tokens as u32, + cached_tokens: complete.cached_tokens as u32, + e2e_latency: start_time.elapsed().as_secs_f64(), + matched_stop, + }; + + result_array.push(GenerateResponse { + text: decoded_text, + output_ids, + meta_info, + }); + } + + Ok(result_array) + } } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 5666823de..d167e7036 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -24,6 +24,9 @@ use crate::server::AppContext; use crate::tokenizer::traits::Tokenizer; use crate::tool_parser::ParserFactory as ToolParserFactory; +use super::context::SharedComponents; +use super::pipeline::RequestPipeline; + /// gRPC router implementation for SGLang #[derive(Clone)] #[allow(dead_code)] @@ -38,8 +41,8 @@ pub struct GrpcRouter { retry_config: RetryConfig, configured_reasoning_parser: Option, configured_tool_parser: Option, - pipeline: super::pipeline::ChatCompletionPipeline, - shared_components: Arc, + pipeline: RequestPipeline, + shared_components: Arc, } impl GrpcRouter { @@ -66,36 +69,21 @@ impl GrpcRouter { let policy_registry = ctx.policy_registry.clone(); // Create shared components for pipeline - let shared_components = Arc::new(super::context::SharedComponents { + let shared_components = Arc::new(SharedComponents { tokenizer: tokenizer.clone(), tool_parser_factory: tool_parser_factory.clone(), reasoning_parser_factory: reasoning_parser_factory.clone(), }); - // Create response processor - let processor = super::processing::ResponseProcessor::new( - tokenizer.clone(), - tool_parser_factory.clone(), - reasoning_parser_factory.clone(), - ctx.configured_tool_parser.clone(), - ctx.configured_reasoning_parser.clone(), - ); - - // Create streaming processor - let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new( - tokenizer.clone(), - tool_parser_factory.clone(), - reasoning_parser_factory.clone(), - ctx.configured_tool_parser.clone(), - ctx.configured_reasoning_parser.clone(), - )); - // Create pipeline - let pipeline = super::pipeline::ChatCompletionPipeline::new_regular( + let pipeline = RequestPipeline::new_regular( worker_registry.clone(), policy_registry.clone(), - processor, - streaming_processor, + tokenizer.clone(), + tool_parser_factory.clone(), + reasoning_parser_factory.clone(), + ctx.configured_tool_parser.clone(), + ctx.configured_reasoning_parser.clone(), ); Ok(GrpcRouter {