diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs
index 2f99811fa..dc60f9c2b 100644
--- a/sgl-router/src/protocols/spec.rs
+++ b/sgl-router/src/protocols/spec.rs
@@ -2066,6 +2066,40 @@ impl GenerationRequest for GenerateRequest {
     }
 }
 
+// TODO(generate): Define GenerateResponse and GenerateChoice structs
+//
+// Required for pipeline generate response processing (see grpc/pipeline.rs:931-964)
+//
+// #[derive(Debug, Clone, Serialize, Deserialize)]
+// pub struct GenerateResponse {
+//     pub id: String,
+//     pub object: String, // "text.completion"
+//     pub created: u64,
+//     pub model: String,
+//     pub choices: Vec<GenerateChoice>,
+//     #[serde(skip_serializing_if = "Option::is_none")]
+//     pub usage: Option<Usage>,
+//     #[serde(skip_serializing_if = "Option::is_none")]
+//     pub system_fingerprint: Option<String>,
+// }
+//
+// #[derive(Debug, Clone, Serialize, Deserialize)]
+// pub struct GenerateChoice {
+//     pub index: u32,
+//     pub text: String,
+//     #[serde(skip_serializing_if = "Option::is_none")]
+//     pub output_ids: Option<Vec<u32>>,
+//     #[serde(skip_serializing_if = "Option::is_none")]
+//     pub finish_reason: Option<String>,
+//     #[serde(skip_serializing_if = "Option::is_none")]
+//     pub logprobs: Option<Value>,
+//     #[serde(skip_serializing_if = "Option::is_none")]
+//     pub matched_stop: Option<Value>,
+// }
+//
+// Note: Verify if similar structs already exist elsewhere before implementing.
+// May need streaming variant (GenerateStreamResponse) as well.
+
 // Constants for rerank API
 pub const DEFAULT_MODEL_NAME: &str = "default";
 
diff --git a/sgl-router/src/routers/grpc/context.rs b/sgl-router/src/routers/grpc/context.rs
new file mode 100644
index 000000000..f6bd73462
--- /dev/null
+++ b/sgl-router/src/routers/grpc/context.rs
@@ -0,0 +1,398 @@
+//! Request context types for gRPC router pipeline
+//!
+//! This module provides the core context types that flow through the router pipeline,
+//! eliminating deep parameter passing chains and providing a single source of truth
+//! for request state.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use axum::http::HeaderMap;
+use serde_json::Value;
+
+use crate::core::Worker;
+use crate::grpc_client::{proto, SglangSchedulerClient};
+use crate::protocols::spec::{ChatCompletionRequest, ChatCompletionResponse, GenerateRequest};
+use crate::reasoning_parser::ReasoningParserFactory;
+use crate::tokenizer::stop::StopSequenceDecoder;
+use crate::tokenizer::traits::Tokenizer;
+use crate::tool_parser::ToolParserFactory;
+
+// ============================================================================
+// Core Context Types
+// ============================================================================
+
+/// Main request processing context
+///
+/// This is the single source of truth for all request state as it flows
+/// through the pipeline stages. Uses Rust's type system to enforce proper
+/// stage ordering at compile time.
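A minimal usage sketch (editor's illustration, not part of the patch) of the builders defined further down in this file; `request` and the shared component handles are assumed to come from the router's existing setup code:

    fn build_chat_context(
        request: ChatCompletionRequest,
        components: Arc<SharedComponents>,
    ) -> RequestContext {
        // for_chat only stores the request, headers, and model id;
        // later pipeline stages fill in ProcessingState.
        let ctx = RequestContext::for_chat(request, None, Some("default".to_string()), components);
        debug_assert!(ctx.is_chat() && !ctx.is_generate());
        ctx
    }
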
+pub struct RequestContext {
+    // === Input (Immutable) ===
+    pub input: RequestInput,
+
+    // === Shared Components (Immutable References) ===
+    pub components: Arc<SharedComponents>,
+
+    // === Processing State (Mutable, evolves through pipeline) ===
+    pub state: ProcessingState,
+}
+
+/// Immutable request input
+pub struct RequestInput {
+    pub request_type: RequestType,
+    pub headers: Option<HeaderMap>,
+    pub model_id: Option<String>,
+}
+
+/// Request type variants
+pub enum RequestType {
+    Chat(Box<ChatCompletionRequest>),
+    Generate(Box<GenerateRequest>),
+}
+
+/// Shared components (injected once at creation)
+pub struct SharedComponents {
+    pub tokenizer: Arc<dyn Tokenizer>,
+    pub tool_parser_factory: ToolParserFactory,
+    pub reasoning_parser_factory: ReasoningParserFactory,
+}
+
+/// Mutable processing state (evolves through pipeline stages)
+#[derive(Default)]
+pub struct ProcessingState {
+    // Stage 1: Preparation outputs
+    pub preparation: Option<PreparationOutput>,
+
+    // Stage 2: Worker selection outputs
+    pub workers: Option<WorkerSelection>,
+
+    // Stage 3: Client acquisition outputs
+    pub clients: Option<ClientSelection>,
+
+    // Stage 4: Request building outputs
+    pub proto_request: Option<proto::GenerateRequest>,
+
+    // Stage 5: Dispatch metadata
+    pub dispatch: Option<DispatchMetadata>,
+
+    // Stage 6: Response processing state
+    pub response: ResponseState,
+}
+
+// ============================================================================
+// Stage-Specific Output Types
+// ============================================================================
+
+/// Output from preparation stage (Step 1)
+pub struct PreparationOutput {
+    /// Original text (for chat) or resolved text (for generate)
+    pub original_text: Option<String>,
+
+    /// Tokenized input
+    pub token_ids: Vec<u32>,
+
+    /// Processed messages (chat only)
+    pub processed_messages: Option<super::ProcessedMessages>,
+
+    /// Tool call constraints (if applicable)
+    pub tool_constraints: Option<(String, String)>,
+
+    /// Filtered request (if tools were filtered)
+    pub filtered_request: Option<ChatCompletionRequest>,
+}
+
+/// Worker selection (Step 2)
+pub enum WorkerSelection {
+    Single {
+        worker: Arc<dyn Worker>,
+    },
+    Dual {
+        prefill: Arc<dyn Worker>,
+        decode: Arc<dyn Worker>,
+    },
+}
+
+/// Client selection (Step 3)
+pub enum ClientSelection {
+    Single {
+        client: SglangSchedulerClient,
+    },
+    Dual {
+        prefill: SglangSchedulerClient,
+        decode: SglangSchedulerClient,
+    },
+}
+
+/// Dispatch metadata (Step 5)
+#[derive(Clone)]
+pub struct DispatchMetadata {
+    pub request_id: String,
+    pub model: String,
+    pub created: u64,
+    pub weight_version: Option<String>,
+    pub is_streaming: bool,
+}
+
+/// Response processing state (Step 6)
+#[derive(Default)]
+pub struct ResponseState {
+    /// Stop sequence decoder
+    pub stop_decoder: Option<StopSequenceDecoder>,
+
+    /// Per-index streaming state (for n>1 support)
+    pub streaming: StreamingState,
+
+    /// Collected responses (non-streaming)
+    pub collected: Option<Vec<proto::GenerateComplete>>,
+
+    /// Execution result (streams from workers)
+    pub execution_result: Option<ExecutionResult>,
+
+    /// Final processed response
+    pub final_response: Option<FinalResponse>,
+}
+
+/// Streaming state (per-choice tracking)
+#[derive(Default)]
+pub struct StreamingState {
+    pub is_firsts: HashMap<u32, bool>,
+    pub stream_buffers: HashMap<u32, String>,
+    pub finish_reasons: HashMap<u32, String>,
+    pub matched_stops: HashMap<u32, Option<Value>>,
+    pub prompt_tokens: HashMap<u32, u32>,
+    pub completion_tokens: HashMap<u32, u32>,
+    pub cached_tokens: HashMap<u32, u32>,
+
+    // Parser state (lazy initialization per index)
+    pub reasoning_parsers:
+        HashMap<u32, Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>>,
+    pub tool_parsers:
+        HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>>,
+    pub has_tool_calls: HashMap<u32, bool>,
+}
+
+// ============================================================================
+// Context Builders
+// ============================================================================
+
+impl RequestContext {
+    ///
Create context for chat completion request + pub fn for_chat( + request: ChatCompletionRequest, + headers: Option, + model_id: Option, + components: Arc, + ) -> Self { + Self { + input: RequestInput { + request_type: RequestType::Chat(Box::new(request)), + headers, + model_id, + }, + components, + state: ProcessingState::default(), + } + } + + /// Create context for generate request + pub fn for_generate( + request: GenerateRequest, + headers: Option, + model_id: Option, + components: Arc, + ) -> Self { + Self { + input: RequestInput { + request_type: RequestType::Generate(Box::new(request)), + headers, + model_id, + }, + components, + state: ProcessingState::default(), + } + } + + /// Get reference to original request (type-safe) + pub fn request(&self) -> &RequestType { + &self.input.request_type + } + + /// Get chat request (panics if not chat) + pub fn chat_request(&self) -> &ChatCompletionRequest { + match &self.input.request_type { + RequestType::Chat(req) => req.as_ref(), + _ => panic!("Expected chat request"), + } + } + + /// Try to get chat request + pub fn try_chat_request(&self) -> Option<&ChatCompletionRequest> { + match &self.input.request_type { + RequestType::Chat(req) => Some(req.as_ref()), + _ => None, + } + } + + /// Get generate request (panics if not generate) + pub fn generate_request(&self) -> &GenerateRequest { + match &self.input.request_type { + RequestType::Generate(req) => req.as_ref(), + _ => panic!("Expected generate request"), + } + } + + /// Try to get generate request + pub fn try_generate_request(&self) -> Option<&GenerateRequest> { + match &self.input.request_type { + RequestType::Generate(req) => Some(req.as_ref()), + _ => None, + } + } + + /// Check if request is streaming + pub fn is_streaming(&self) -> bool { + match &self.input.request_type { + RequestType::Chat(req) => req.stream, + RequestType::Generate(req) => req.stream, + } + } + + /// Check if request is chat + pub fn is_chat(&self) -> bool { + matches!(&self.input.request_type, RequestType::Chat(_)) + } + + /// Check if request is generate + pub fn is_generate(&self) -> bool { + matches!(&self.input.request_type, RequestType::Generate(_)) + } +} + +// ============================================================================ +// Default Implementations +// ============================================================================ + +// ============================================================================ +// Helper Methods +// ============================================================================ + +impl WorkerSelection { + pub fn is_dual(&self) -> bool { + matches!(self, Self::Dual { .. }) + } + + pub fn single(&self) -> Option<&Arc> { + match self { + Self::Single { worker } => Some(worker), + _ => None, + } + } + + #[allow(clippy::type_complexity)] + pub fn dual(&self) -> Option<(&Arc, &Arc)> { + match self { + Self::Dual { prefill, decode } => Some((prefill, decode)), + _ => None, + } + } + + pub fn prefill_worker(&self) -> Option<&Arc> { + match self { + Self::Dual { prefill, .. } => Some(prefill), + _ => None, + } + } + + pub fn decode_worker(&self) -> Option<&Arc> { + match self { + Self::Dual { decode, .. } => Some(decode), + _ => None, + } + } +} + +impl ClientSelection { + pub fn is_dual(&self) -> bool { + matches!(self, Self::Dual { .. 
}) + } + + pub fn single(&self) -> Option<&SglangSchedulerClient> { + match self { + Self::Single { client } => Some(client), + _ => None, + } + } + + pub fn single_mut(&mut self) -> Option<&mut SglangSchedulerClient> { + match self { + Self::Single { client } => Some(client), + _ => None, + } + } + + pub fn dual(&self) -> Option<(&SglangSchedulerClient, &SglangSchedulerClient)> { + match self { + Self::Dual { prefill, decode } => Some((prefill, decode)), + _ => None, + } + } + + pub fn dual_mut(&mut self) -> Option<(&mut SglangSchedulerClient, &mut SglangSchedulerClient)> { + match self { + Self::Dual { prefill, decode } => Some((prefill, decode)), + _ => None, + } + } + + pub fn prefill_client(&self) -> Option<&SglangSchedulerClient> { + match self { + Self::Dual { prefill, .. } => Some(prefill), + _ => None, + } + } + + pub fn prefill_client_mut(&mut self) -> Option<&mut SglangSchedulerClient> { + match self { + Self::Dual { prefill, .. } => Some(prefill), + _ => None, + } + } + + pub fn decode_client(&self) -> Option<&SglangSchedulerClient> { + match self { + Self::Dual { decode, .. } => Some(decode), + _ => None, + } + } + + pub fn decode_client_mut(&mut self) -> Option<&mut SglangSchedulerClient> { + match self { + Self::Dual { decode, .. } => Some(decode), + _ => None, + } + } +} + +// ============================================================================ +// Execution and Response Types +// ============================================================================ + +use tonic::codec::Streaming; + +/// Result of request execution (streams from workers) +pub enum ExecutionResult { + Single { + stream: Streaming, + }, + Dual { + prefill: Streaming, + decode: Box>, + }, +} + +/// Final processed response +pub enum FinalResponse { + Chat(ChatCompletionResponse), + Generate(Box), +} diff --git a/sgl-router/src/routers/grpc/mod.rs b/sgl-router/src/routers/grpc/mod.rs index 03a1f9ac2..2378ae9b9 100644 --- a/sgl-router/src/routers/grpc/mod.rs +++ b/sgl-router/src/routers/grpc/mod.rs @@ -3,8 +3,12 @@ use crate::grpc_client::proto; use crate::protocols::spec::StringOrArray; +pub mod context; pub mod pd_router; +pub mod pipeline; +pub mod processing; pub mod router; +pub mod streaming; pub mod utils; /// Processed chat messages ready for gRPC generation diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 1d39d748f..daad4b9d8 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -6,19 +6,16 @@ use crate::grpc_client::proto; use crate::grpc_client::SglangSchedulerClient; use crate::policies::PolicyRegistry; use crate::protocols::spec::{ - ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionStreamResponse, ChatLogProbs, ChatLogProbsContent, ChatMessageDelta, - ChatStreamChoice, CompletionRequest, EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, - GenerateRequest, InputIds, RerankRequest, ResponsesGetParams, ResponsesRequest, StringOrArray, - Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, TopLogProb, Usage, + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds, + RerankRequest, ResponsesGetParams, ResponsesRequest, }; -use crate::reasoning_parser::{ParserResult, ReasoningParser, ReasoningParserFactory}; +use crate::reasoning_parser::ReasoningParserFactory; use crate::routers::http::pd_types::generate_room_id; use crate::routers::{grpc, RouterTrait}; use crate::server::AppContext; use 
crate::tokenizer::traits::Tokenizer; -use crate::tokenizer::{SequenceDecoderOutput, StopSequenceDecoder}; -use crate::tool_parser::{StreamingParseResult, ToolParser, ToolParserFactory}; +use crate::tokenizer::SequenceDecoderOutput; +use crate::tool_parser::ToolParserFactory; use async_trait::async_trait; use axum::{ body::Body, @@ -29,16 +26,14 @@ use axum::{ }; use grpc::utils; use proto::generate_response::Response::{Chunk, Complete, Error}; -use serde_json::Value; use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; -use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::mpsc::unbounded_channel; use tokio::sync::mpsc::UnboundedSender; use tokio_stream::Stream; use tokio_stream::StreamExt; -use tracing::{debug, error, warn}; +use tracing::{debug, error}; use uuid::Uuid; /// gRPC PD (Prefill-Decode) router implementation for SGLang @@ -55,6 +50,10 @@ pub struct GrpcPDRouter { retry_config: RetryConfig, configured_reasoning_parser: Option, configured_tool_parser: Option, + // Pipeline for non-streaming requests + pipeline: super::pipeline::ChatCompletionPipeline, + // Shared components for pipeline + shared_components: Arc, } impl GrpcPDRouter { @@ -81,6 +80,39 @@ impl GrpcPDRouter { .ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())? .clone(); + // Create shared components for pipeline + let shared_components = Arc::new(super::context::SharedComponents { + tokenizer: tokenizer.clone(), + tool_parser_factory: tool_parser_factory.clone(), + reasoning_parser_factory: reasoning_parser_factory.clone(), + }); + + // Create response processor + let processor = super::processing::ResponseProcessor::new( + tokenizer.clone(), + tool_parser_factory.clone(), + reasoning_parser_factory.clone(), + ctx.configured_tool_parser.clone(), + ctx.configured_reasoning_parser.clone(), + ); + + // Create streaming processor + let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new( + tokenizer.clone(), + tool_parser_factory.clone(), + reasoning_parser_factory.clone(), + ctx.configured_tool_parser.clone(), + ctx.configured_reasoning_parser.clone(), + )); + + // Create PD pipeline + let pipeline = super::pipeline::ChatCompletionPipeline::new_pd( + worker_registry.clone(), + policy_registry.clone(), + processor, + streaming_processor, + ); + Ok(GrpcPDRouter { worker_registry, policy_registry, @@ -92,6 +124,8 @@ impl GrpcPDRouter { retry_config: ctx.router_config.effective_retry_config(), configured_reasoning_parser: ctx.configured_reasoning_parser.clone(), configured_tool_parser: ctx.configured_tool_parser.clone(), + pipeline, + shared_components, }) } @@ -314,7 +348,7 @@ impl GrpcPDRouter { /// Main route_chat implementation with PD dual dispatch async fn route_chat_impl( &self, - _headers: Option<&HeaderMap>, + headers: Option<&HeaderMap>, body: &ChatCompletionRequest, model_id: Option<&str>, ) -> Response { @@ -323,91 +357,15 @@ impl GrpcPDRouter { model_id ); - // Step 1: Filter tools if needed for allowed_tools or specific function - let body_ref = utils::filter_tools_for_request(body); - - // Step 2: Process messages and apply chat template - let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { - Ok(msgs) => msgs, - Err(e) => { - return utils::bad_request_error(e.to_string()); - } - }; - - // Step 3: Tokenize the processed text - let encoding = match self.tokenizer.encode(&processed_messages.text) { - Ok(encoding) => encoding, - Err(e) => { - return utils::internal_error_message(format!("Tokenization 
failed: {}", e)); - } - }; - - // Step 4: Build tool constraints if needed - // body_ref already has filtered tools if needed - let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| { - utils::generate_tool_constraints(tools, &body.tool_choice, &body.model) - }); - - let token_ids = encoding.token_ids().to_vec(); - debug!("Tokenized {} tokens from input", token_ids.len()); - - // Step 5: Select prefill-decode worker pair - let (prefill_worker, decode_worker) = match self - .select_pd_pair(Some(&processed_messages.text), model_id) + // Use pipeline for ALL requests (streaming and non-streaming) + self.pipeline + .execute_chat( + body.clone(), + headers.cloned(), + model_id.map(|s| s.to_string()), + self.shared_components.clone(), + ) .await - { - Ok(pair) => pair, - Err(e) => { - return utils::service_unavailable_error(e); - } - }; - - debug!( - "Selected PD pair: prefill={}, decode={}", - prefill_worker.url(), - decode_worker.url() - ); - - // Step 6: Get gRPC clients for both workers - let prefill_client = match utils::get_grpc_client_from_worker(&prefill_worker).await { - Ok(client) => client, - Err(response) => return response, - }; - - let decode_client = match utils::get_grpc_client_from_worker(&decode_worker).await { - Ok(client) => client, - Err(response) => return response, - }; - - // Step 7: Build the base gRPC request - let request_id = format!("chatcmpl-{}", Uuid::new_v4()); - let mut request = match prefill_client.build_generate_request( - request_id.clone(), - &body_ref, - processed_messages.text.clone(), - token_ids, - processed_messages.multimodal_inputs, - tool_call_constraint, - ) { - Ok(request) => request, - Err(e) => { - return utils::bad_request_error(format!("Invalid request parameters: {}", e)); - } - }; - - // Step 8: Inject bootstrap metadata into the request - if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) { - return utils::internal_error_message(e); - } - - // Step 9: Handle streaming vs non-streaming - if body.stream { - self.handle_streaming_chat(prefill_client, decode_client, request, body) - .await - } else { - self.handle_non_streaming_chat(prefill_client, decode_client, request, body) - .await - } } /// Resolve the generate input into optional original text and token IDs @@ -441,109 +399,6 @@ impl GrpcPDRouter { Err("Either `text` or `input_ids` must be provided".to_string()) } - /// Submit request and handle streaming response for chat completions (PD mode) - async fn handle_streaming_chat( - &self, - mut prefill_client: SglangSchedulerClient, - mut decode_client: SglangSchedulerClient, - request: proto::GenerateRequest, - original_request: &ChatCompletionRequest, - ) -> Response { - let request_id = request.request_id.clone(); - let model = original_request.model.clone(); - - // Create channel for SSE streaming - let (tx, rx) = unbounded_channel::>(); - - // Send requests in parallel to both prefill and decode workers - debug!("Starting concurrent streaming requests to prefill and decode workers"); - let prefill_request = request.clone(); - let decode_request = request; - - let (prefill_result, decode_result) = tokio::join!( - prefill_client.generate(prefill_request), - decode_client.generate(decode_request) - ); - - // Get prefill stream - let prefill_stream = match prefill_result { - Ok(s) => s, - Err(e) => { - return utils::internal_error_message(format!( - "Prefill worker failed to start: {}", - e - )); - } - }; - - // Get decode stream - this is what we'll process for output - let decode_stream = match 
decode_result { - Ok(s) => s, - Err(e) => { - return utils::internal_error_message(format!( - "Decode worker failed to start: {}", - e - )); - } - }; - - let stop_params = ( - original_request.stop.clone(), - original_request.stop_token_ids.clone(), - original_request.skip_special_tokens, - original_request.no_stop_trim, - ); - - // Spawn processing task for both streams - let self_clone = self.clone(); - let original_request_clone = original_request.clone(); - tokio::spawn(async move { - let result = Self::process_dual_streaming_chunks( - &self_clone, - prefill_stream, - decode_stream, - request_id, - model, - stop_params, - original_request_clone, - &tx, - ) - .await; - - if let Err(e) = result { - let error_chunk = format!( - "data: {}\n\n", - serde_json::json!({ - "error": { - "message": e, - "type": "internal_error" - } - }) - ); - let _ = tx.send(Ok(bytes::Bytes::from(error_chunk))); - } - - // Send DONE marker - let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n"))); - }); - - // Create response with SSE headers - let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); - let mut response = Response::new(Body::from_stream(stream)); - *response.status_mut() = StatusCode::OK; - response.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("text/event-stream"), - ); - response - .headers_mut() - .insert("Cache-Control", HeaderValue::from_static("no-cache")); - response - .headers_mut() - .insert("Connection", HeaderValue::from_static("keep-alive")); - response - } - /// Submit request and handle streaming response for generate endpoint (PD mode) async fn handle_streaming_generate( &self, @@ -766,778 +621,6 @@ impl GrpcPDRouter { Ok(()) } - /// Process dual streaming chunks (prefill + decode) and send SSE events (PD mode) - #[allow(clippy::too_many_arguments)] - async fn process_dual_streaming_chunks( - router: &GrpcPDRouter, - mut prefill_stream: impl Stream> + Unpin, - mut decode_stream: impl Stream> + Unpin, - request_id: String, - model: String, - stop_params: (Option, Option>, bool, bool), - original_request: ChatCompletionRequest, - tx: &UnboundedSender>, - ) -> Result<(), String> { - // Extract request parameters - let separate_reasoning = original_request.separate_reasoning; - let tool_choice = &original_request.tool_choice; - let tools = &original_request.tools; - let history_tool_calls_count = utils::get_history_tool_calls_count(&original_request); - let stream_options = &original_request.stream_options; - - // Phase 1: Initialize state tracking (per-index for n>1 support) - let mut is_firsts: HashMap = HashMap::new(); - let mut stream_buffers: HashMap = HashMap::new(); - let mut finish_reasons: HashMap = HashMap::new(); - let mut matched_stops: HashMap> = HashMap::new(); - let mut prompt_tokens: HashMap = HashMap::new(); - let mut completion_tokens: HashMap = HashMap::new(); - let mut cached_tokens: HashMap = HashMap::new(); - - // Parser state (lazy initialization per index) - type PooledReasoningParser = Arc>>; - let mut reasoning_parsers: HashMap = HashMap::new(); - - type PooledToolParser = Arc>>; - let mut tool_parsers: HashMap = HashMap::new(); - let mut has_tool_calls: HashMap = HashMap::new(); - - // Create stop decoder - let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params; - let mut stop_decoder = utils::create_stop_decoder( - &router.tokenizer, - stop.as_ref(), - stop_token_ids.as_ref(), - skip_special_tokens, - no_stop_trim, - ); - - let created = SystemTime::now() - .duration_since(UNIX_EPOCH) - 
.unwrap_or_default() - .as_secs(); - - // Phase 1.5: Collect input_logprobs from prefill stream if requested - // Note: In PD mode, input_logprobs come from prefill worker - // TODO: Store and emit input_logprobs when implementing prompt logprobs in streaming - if original_request.logprobs { - while let Some(response) = prefill_stream.next().await { - let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?; - match gen_response.response { - Some(Complete(_complete)) => { - // Input logprobs collected but not yet used in streaming - // (OpenAI spec doesn't require prompt logprobs in streaming responses) - break; - } - Some(Error(error)) => { - return Err(format!("Prefill error: {}", error.message)); - } - _ => continue, - } - } - } - - // Phase 2: Main streaming loop (decode stream) - while let Some(response) = decode_stream.next().await { - let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; - - match gen_response.response { - Some(Chunk(chunk)) => { - let index = chunk.index; - - // Process tokens through stop decoder - let (chunk_text, _should_stop) = - Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids); - - if chunk_text.is_empty() { - continue; - } - - // Process logprobs if present - let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs { - match router.convert_proto_to_openai_logprobs(proto_logprobs) { - Ok(logprobs) => Some(logprobs), - Err(e) => { - warn!("Failed to process logprobs: {}", e); - None - } - } - } else { - None - }; - - // Initialize stream buffer if first time - let stream_buffer = stream_buffers.entry(index).or_default(); - - // Send first chunk with role - if is_firsts.get(&index).copied().unwrap_or(true) { - let first_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: None, - reasoning_content: None, - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }; - tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&first_chunk)))) - .map_err(|_| "Failed to send first chunk".to_string())?; - is_firsts.insert(index, false); - } - - // Calculate delta - let mut delta = chunk_text; - stream_buffer.push_str(&delta); - - // Reasoning content handling - let in_reasoning = if separate_reasoning { - let (normal_text, reasoning_chunk, in_reasoning) = router - .process_reasoning_stream( - &delta, - index, - &mut reasoning_parsers, - &request_id, - &model, - created, - ); - if let Some(chunk) = reasoning_chunk { - tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&chunk)))) - .map_err(|_| "Failed to send reasoning chunk".to_string())?; - } - delta = normal_text; - in_reasoning - } else { - false - }; - - // Tool call handling - let tool_choice_enabled = - !matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None))); - - if !in_reasoning && tool_choice_enabled && tools.is_some() { - let (should_skip, tool_chunks) = router - .process_tool_calls_stream( - &delta, - index, - &mut tool_parsers, - &mut has_tool_calls, - tools.as_ref().unwrap(), - &request_id, - &model, - created, - history_tool_calls_count, - ) - .await; - - for chunk in tool_chunks { - tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&chunk)))) - .map_err(|_| "Failed to send tool call chunk".to_string())?; - } - - if 
should_skip { - continue; - } - } - - // Regular content emission - if !delta.is_empty() { - let content_chunk = Self::create_content_chunk( - delta, - index, - &request_id, - &model, - created, - choice_logprobs, - ); - tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk( - &content_chunk, - )))) - .map_err(|_| "Failed to send content chunk".to_string())?; - } - } - Some(Complete(complete)) => { - // Flush any remaining text - if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() { - if !text.is_empty() { - let index = complete.index; - let stream_buffer = stream_buffers.entry(index).or_default(); - stream_buffer.push_str(&text); - - let content_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: Some(text), - tool_calls: None, - reasoning_content: None, - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }; - - let sse_chunk = serde_json::to_string(&content_chunk) - .map_err(|e| format!("Failed to serialize content chunk: {}", e))?; - tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk)))) - .map_err(|_| "Failed to send flushed content".to_string())?; - } - } - - // Store metadata - let index = complete.index; - prompt_tokens.insert(index, complete.prompt_tokens as u32); - completion_tokens.insert(index, complete.completion_tokens as u32); - cached_tokens.insert(index, complete.cached_tokens as u32); - finish_reasons.insert(index, complete.finish_reason.clone()); - - // Extract matched_stop - let matched_stop_value = match &complete.matched_stop { - Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { - Some(Value::Number(serde_json::Number::from(*token_id))) - } - Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { - Some(Value::String(stop_str.clone())) - } - None => None, - }; - matched_stops.insert(index, matched_stop_value); - - break; - } - Some(Error(error)) => { - return Err(error.message); - } - None => continue, - } - } - - // Phase 3: Check unstreamed tool args - for (index, parser) in &tool_parsers { - let parser_guard = parser.lock().await; - if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() { - for tool_call_item in unstreamed_items { - let tool_call_delta = ToolCallDelta { - index: tool_call_item.tool_index as u32, - id: None, - tool_type: None, - function: Some(FunctionCallDelta { - name: None, - arguments: if !tool_call_item.parameters.is_empty() { - Some(tool_call_item.parameters) - } else { - None - }, - }), - }; - - let tool_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index: *index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: Some(vec![tool_call_delta]), - reasoning_content: None, - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }; - - let sse_chunk = serde_json::to_string(&tool_chunk) - .map_err(|e| format!("Failed to serialize tool chunk: {}", e))?; - tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk)))) - .map_err(|_| "Failed to send unstreamed tool args".to_string())?; - } - } - } - - // Phase 4: Finish reason chunks - for 
(index, finish_reason) in finish_reasons.iter() { - let final_finish_reason = - if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" { - "tool_calls".to_string() - } else { - finish_reason.clone() - }; - - let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone()); - - let finish_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index: *index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: None, - reasoning_content: None, - }, - logprobs: None, - finish_reason: Some(final_finish_reason), - matched_stop: matched_stop_value, - }], - usage: None, - }; - - let sse_chunk = serde_json::to_string(&finish_chunk) - .map_err(|e| format!("Failed to serialize finish chunk: {}", e))?; - tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk)))) - .map_err(|_| "Failed to send finish chunk".to_string())?; - } - - // Phase 5: Usage chunk - if let Some(stream_opts) = stream_options { - if stream_opts.include_usage.unwrap_or(false) { - let total_prompt: u32 = prompt_tokens.values().sum(); - let total_completion: u32 = completion_tokens.values().sum(); - - let usage_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![], - usage: Some(Usage { - prompt_tokens: total_prompt, - completion_tokens: total_completion, - total_tokens: total_prompt + total_completion, - completion_tokens_details: None, - }), - }; - - let sse_chunk = serde_json::to_string(&usage_chunk) - .map_err(|e| format!("Failed to serialize usage chunk: {}", e))?; - tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk)))) - .map_err(|_| "Failed to send usage chunk".to_string())?; - } - } - - Ok(()) - } - - /// Helper: Process reasoning content in streaming mode - fn process_reasoning_stream( - &self, - delta: &str, - index: u32, - reasoning_parsers: &mut HashMap>>>, - request_id: &str, - model: &str, - created: u64, - ) -> (String, Option, bool) { - // Get or create parser for this index - reasoning_parsers.entry(index).or_insert_with(|| { - utils::get_reasoning_parser( - &self.reasoning_parser_factory, - self.configured_reasoning_parser.as_ref(), - model, - ) - }); - - if let Some(pooled_parser) = reasoning_parsers.get(&index) { - let (parse_result, in_reasoning) = { - let mut parser = pooled_parser.lock().unwrap(); - let result = parser.parse_reasoning_streaming_incremental(delta); - let in_reasoning = parser.is_in_reasoning(); - (result, in_reasoning) - }; - - match parse_result { - Ok(ParserResult { - reasoning_text, - normal_text, - }) => { - let chunk = if !reasoning_text.is_empty() { - Some(ChatCompletionStreamResponse { - id: request_id.to_string(), - object: "chat.completion.chunk".to_string(), - created, - model: model.to_string(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: None, - reasoning_content: Some(reasoning_text), - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }) - } else { - None - }; - return (normal_text, chunk, in_reasoning); - } - Err(e) => { - warn!("Reasoning parsing error: {}", e); - } - } - } - - (delta.to_string(), None, false) - } - - /// 
Helper: Process tool calls in streaming mode - #[allow(clippy::too_many_arguments)] - async fn process_tool_calls_stream( - &self, - delta: &str, - index: u32, - tool_parsers: &mut HashMap>>>, - has_tool_calls: &mut HashMap, - tools: &[Tool], - request_id: &str, - model: &str, - created: u64, - history_tool_calls_count: usize, - ) -> (bool, Vec) { - let mut chunks = Vec::new(); - - // Get or create parser for this index - tool_parsers.entry(index).or_insert_with(|| { - utils::get_tool_parser( - &self.tool_parser_factory, - self.configured_tool_parser.as_ref(), - model, - ) - }); - - if let Some(pooled_parser) = tool_parsers.get(&index) { - let mut parser = pooled_parser.lock().await; - match parser.parse_incremental(delta, tools).await { - Ok(StreamingParseResult { normal_text, calls }) => { - // Emit normal text if present - if !normal_text.is_empty() { - chunks.push(ChatCompletionStreamResponse { - id: request_id.to_string(), - object: "chat.completion.chunk".to_string(), - created, - model: model.to_string(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: Some(normal_text), - tool_calls: None, - reasoning_content: None, - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }); - } - - // Emit tool call chunks - for tool_call_item in calls { - has_tool_calls.insert(index, true); - - let tool_call_id = if let Some(ref name) = tool_call_item.name { - Some(utils::generate_tool_call_id( - model, - name, - tool_call_item.tool_index, - history_tool_calls_count, - )) - } else { - None - }; - - let tool_call_delta = ToolCallDelta { - index: tool_call_item.tool_index as u32, - id: tool_call_id, - tool_type: if tool_call_item.name.is_some() { - Some("function".to_string()) - } else { - None - }, - function: Some(FunctionCallDelta { - name: tool_call_item.name, - arguments: if !tool_call_item.parameters.is_empty() { - Some(tool_call_item.parameters) - } else { - None - }, - }), - }; - - chunks.push(ChatCompletionStreamResponse { - id: request_id.to_string(), - object: "chat.completion.chunk".to_string(), - created, - model: model.to_string(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: Some(vec![tool_call_delta]), - reasoning_content: None, - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }); - } - - // If we emitted chunks, skip regular content - return (!chunks.is_empty(), chunks); - } - Err(e) => { - warn!("Tool call parsing error: {}", e); - } - } - } - - (false, chunks) - } - - /// Helper: Create content chunk - fn create_content_chunk( - content: String, - index: u32, - request_id: &str, - model: &str, - created: u64, - logprobs: Option, - ) -> ChatCompletionStreamResponse { - ChatCompletionStreamResponse { - id: request_id.to_string(), - object: "chat.completion.chunk".to_string(), - created, - model: model.to_string(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: Some(content), - tool_calls: None, - reasoning_content: None, - }, - logprobs, - finish_reason: None, - matched_stop: None, - }], - usage: None, - } - } - - /// Helper: Format response as SSE chunk - fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String { - format!( - "data: {}\n\n", - 
serde_json::to_string(response).unwrap_or_default() - ) - } - - /// Process a chunk of tokens through the stop decoder - fn process_chunk_tokens( - stop_decoder: &mut StopSequenceDecoder, - token_ids: &[u32], - ) -> (String, bool) { - let mut chunk_text = String::new(); - - for &token_id in token_ids { - match stop_decoder.process_token(token_id).unwrap_or_else(|e| { - debug!( - "Error processing token {}: {}. Treating as Held.", - token_id, e - ); - SequenceDecoderOutput::Held - }) { - SequenceDecoderOutput::Text(text) => { - chunk_text.push_str(&text); - } - SequenceDecoderOutput::StoppedWithText(text) => { - chunk_text.push_str(&text); - return (chunk_text, true); - } - SequenceDecoderOutput::Stopped => { - return (chunk_text, true); - } - SequenceDecoderOutput::Held => {} - } - } - (chunk_text, false) - } - - /// Submit request and handle non-streaming response for chat completions (PD mode) - async fn handle_non_streaming_chat( - &self, - mut prefill_client: SglangSchedulerClient, - mut decode_client: SglangSchedulerClient, - request: proto::GenerateRequest, - original_request: &ChatCompletionRequest, - ) -> Response { - // Step 1: Create stop decoder - let mut stop_decoder = utils::create_stop_decoder( - &self.tokenizer, - original_request.stop.as_ref(), - original_request.stop_token_ids.as_ref(), - original_request.skip_special_tokens, - original_request.no_stop_trim, - ); - - // Step 2: Send requests in parallel - debug!("Sending concurrent requests to prefill and decode workers"); - let prefill_request = request.clone(); - let decode_request = request; - - let (prefill_result, decode_result) = tokio::join!( - prefill_client.generate(prefill_request), - decode_client.generate(decode_request) - ); - - // Step 3: Process prefill stream in parallel - if it fails, assume decode fails - let prefill_stream = match prefill_result { - Ok(s) => s, - Err(e) => { - error!("Failed to start prefill generation: {}", e); - return utils::internal_error_message(format!( - "Prefill worker failed to start: {}", - e - )); - } - }; - - let decode_stream = match decode_result { - Ok(s) => s, - Err(e) => { - error!("Failed to start decode generation: {}", e); - return utils::internal_error_message(format!( - "Decode worker failed to start: {}", - e - )); - } - }; - - // Collect prefill response (for input_logprobs if requested) - let prefill_responses = - match utils::collect_stream_responses(prefill_stream, "Prefill").await { - Ok(responses) => responses, - Err(error_response) => return error_response, - }; - - // Extract input_logprobs from prefill response if available - let prefill_input_logprobs = prefill_responses - .first() - .and_then(|r| r.input_logprobs.clone()); - - // Step 4: Process decode stream (collect all responses for n>1 support) - let all_responses = match utils::collect_stream_responses(decode_stream, "Decode").await { - Ok(responses) => responses, - Err(error_response) => return error_response, - }; - - if all_responses.is_empty() { - return utils::internal_error_static("No responses from decode worker"); - } - - // Process each response into a ChatChoice - let history_tool_calls_count = utils::get_history_tool_calls_count(original_request); - let mut choices = Vec::new(); - for (index, complete) in all_responses.iter().enumerate() { - // Merge prefill input_logprobs if available and requested - let mut complete_with_logprobs = complete.clone(); - if prefill_input_logprobs.is_some() && original_request.logprobs { - complete_with_logprobs.input_logprobs = 
prefill_input_logprobs.clone(); - } - - match self - .process_single_choice( - &complete_with_logprobs, - index, - original_request, - &mut stop_decoder, - history_tool_calls_count, - ) - .await - { - Ok(choice) => choices.push(choice), - Err(e) => { - return utils::internal_error_message(format!( - "Failed to process choice {}: {}", - index, e - )); - } - } - } - - // Aggregate usage information from all responses - let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum(); - let total_completion_tokens: u32 = all_responses - .iter() - .map(|r| r.completion_tokens as u32) - .sum(); - - let usage = Usage { - prompt_tokens: total_prompt_tokens, - completion_tokens: total_completion_tokens, - total_tokens: total_prompt_tokens + total_completion_tokens, - completion_tokens_details: None, - }; - - // Build final ChatCompletionResponse - let response = ChatCompletionResponse { - id: format!("chatcmpl-{}", Uuid::new_v4()), - object: "chat.completion".to_string(), - created: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(), - model: original_request.model.clone(), - choices, - usage: Some(usage), - system_fingerprint: None, - }; - - // Serialize and return JSON response - Json(response).into_response() - } - /// Submit request and handle non-streaming response for generate endpoint (PD mode) async fn handle_non_streaming_generate( &self, @@ -1683,301 +766,6 @@ impl GrpcPDRouter { Json(result_array).into_response() } - - /// Process a single GenerateComplete response into a ChatChoice - async fn process_single_choice( - &self, - complete: &proto::GenerateComplete, - index: usize, - original_request: &ChatCompletionRequest, - stop_decoder: &mut StopSequenceDecoder, - history_tool_calls_count: usize, - ) -> Result { - stop_decoder.reset(); - // Decode tokens - let outputs = stop_decoder - .process_tokens(&complete.output_ids) - .map_err(|e| format!("Failed to process tokens: {}", e))?; - - // Accumulate text with early breaks - let mut final_text = String::new(); - for output in outputs { - match output { - SequenceDecoderOutput::Text(t) => final_text.push_str(&t), - SequenceDecoderOutput::StoppedWithText(t) => { - final_text.push_str(&t); - break; - } - SequenceDecoderOutput::Stopped => break, - SequenceDecoderOutput::Held => {} - } - } - - // Flush remaining text - if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { - final_text.push_str(&t); - } - - // Step 1: Handle reasoning content parsing - let mut reasoning_text: Option = None; - let mut processed_text = final_text; - - // Check if reasoning parsing is enabled and separate_reasoning is requested - if original_request.separate_reasoning { - let pooled_parser = utils::get_reasoning_parser( - &self.reasoning_parser_factory, - self.configured_reasoning_parser.as_ref(), - &original_request.model, - ); - - let mut parser = pooled_parser - .lock() - .map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?; - match parser.detect_and_parse_reasoning(&processed_text) { - Ok(result) => { - if !result.reasoning_text.is_empty() { - reasoning_text = Some(result.reasoning_text); - } - processed_text = result.normal_text; - } - Err(e) => { - return Err(format!("Reasoning parsing error: {}", e)); - } - } - } - - // Step 2: Handle tool call parsing - let mut tool_calls: Option> = None; - - // Check if tool calls should be processed - let tool_choice_enabled = !matches!( - &original_request.tool_choice, - Some(ToolChoice::Value(ToolChoiceValue::None)) - ); - - if 
tool_choice_enabled && original_request.tools.is_some() { - // Check if JSON schema constraint was used (specific function or required mode) - let used_json_schema = match &original_request.tool_choice { - Some(ToolChoice::Function { .. }) => true, - Some(ToolChoice::Value(ToolChoiceValue::Required)) => true, - Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required", - _ => false, - }; - - if used_json_schema { - (tool_calls, processed_text) = utils::parse_json_schema_response( - &processed_text, - &original_request.tool_choice, - ); - } else { - (tool_calls, processed_text) = self - .parse_tool_calls( - &processed_text, - &original_request.model, - history_tool_calls_count, - ) - .await; - } - } - - // Step 3: Use finish reason directly from proto (already OpenAI-compatible string) - let finish_reason_str = &complete.finish_reason; - - // Override finish reason if we have tool calls - let final_finish_reason_str = if tool_calls.is_some() { - "tool_calls" - } else { - finish_reason_str - }; - - // Extract matched_stop information from proto - let matched_stop = match &complete.matched_stop { - Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { - Some(Value::Number(serde_json::Number::from(*token_id))) - } - Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { - Some(Value::String(stop_str.clone())) - } - None => None, - }; - - // Step 4: Convert output logprobs if present - // Note: complete.input_logprobs exists in proto but is not used for chat completions - // (input logprobs are only used in /v1/completions endpoint with echo=true) - let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs { - match self.convert_proto_to_openai_logprobs(proto_logprobs) { - Ok(logprobs) => Some(logprobs), - Err(e) => { - error!("Failed to convert logprobs: {}", e); - None - } - } - } else { - None - }; - - // Step 5: Build ChatCompletionMessage (proper response message type) - let chat_message = ChatCompletionMessage { - role: "assistant".to_string(), - content: if processed_text.is_empty() { - None - } else { - Some(processed_text) - }, - tool_calls, - reasoning_content: reasoning_text, - }; - - // Step 6: Build ChatChoice - let choice = ChatChoice { - index: index as u32, - message: chat_message, - logprobs, - finish_reason: Some(final_finish_reason_str.to_string()), - matched_stop, - hidden_states: None, - }; - - Ok(choice) - } - - /// Parse tool calls using model-specific parser - async fn parse_tool_calls( - &self, - processed_text: &str, - model: &str, - history_tool_calls_count: usize, - ) -> (Option>, String) { - // Get pooled parser for this model - let pooled_parser = utils::get_tool_parser( - &self.tool_parser_factory, - self.configured_tool_parser.as_ref(), - model, - ); - - // Check format detection first - let can_parse = { - let parser = pooled_parser.lock().await; - parser.has_tool_markers(processed_text) - // Lock is dropped here - }; - - if !can_parse { - return (None, processed_text.to_string()); - } - - // Lock again for async parsing - let result = { - let parser = pooled_parser.lock().await; - parser.parse_complete(processed_text).await - // Lock is dropped here - }; - - match result { - Ok((normal_text, parsed_tool_calls)) => { - if parsed_tool_calls.is_empty() { - return (None, normal_text); - } - - let spec_tool_calls = parsed_tool_calls - .into_iter() - .enumerate() - .map(|(index, tc)| { - // Generate ID for this tool call - let id = utils::generate_tool_call_id( - model, - &tc.function.name, - index, - 
history_tool_calls_count, - ); - ToolCall { - id, - tool_type: "function".to_string(), - function: FunctionCallResponse { - name: tc.function.name, - arguments: Some( - serde_json::to_string(&tc.function.arguments) - .unwrap_or_else(|_| "{}".to_string()), - ), - }, - } - }) - .collect(); - (Some(spec_tool_calls), normal_text) - } - Err(e) => { - error!("Tool call parsing error: {}", e); - (None, processed_text.to_string()) - } - } - } - - /// Convert proto LogProbs to OpenAI ChatLogProbs format - /// Note: Always decodes with skip_special_tokens=false to show actual tokens generated - fn convert_proto_to_openai_logprobs( - &self, - proto_logprobs: &proto::OutputLogProbs, - ) -> Result { - let mut content_items = Vec::new(); - - // Decode token IDs to text (always with skip_special_tokens=false for logprobs) - let token_texts: Vec = proto_logprobs - .token_ids - .iter() - .map(|&token_id| { - self.tokenizer - .decode(&[token_id as u32], false) - .unwrap_or_else(|_| format!("", token_id)) - }) - .collect(); - - // Build ChatLogProbsContent for each token - for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() { - let token_text = token_texts.get(i).cloned().unwrap_or_default(); - let bytes = Some(token_text.as_bytes().to_vec()); - - // Build top_logprobs for this position - let mut top_logprobs = Vec::new(); - if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) { - // Decode top token IDs (always with skip_special_tokens=false) - let top_token_texts: Vec = top_logprobs_entry - .token_ids - .iter() - .map(|&tid| { - self.tokenizer - .decode(&[tid as u32], false) - .unwrap_or_else(|_| format!("", tid)) - }) - .collect(); - - for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry - .values - .iter() - .zip(top_logprobs_entry.token_ids.iter()) - .enumerate() - { - if let Some(top_token_text) = top_token_texts.get(j) { - top_logprobs.push(TopLogProb { - token: top_token_text.clone(), - logprob: top_logprob, - bytes: Some(top_token_text.as_bytes().to_vec()), - }); - } - } - } - - content_items.push(ChatLogProbsContent { - token: token_text, - logprob, - bytes, - top_logprobs, - }); - } - - Ok(ChatLogProbs::Detailed { - content: (!content_items.is_empty()).then_some(content_items), - }) - } } impl std::fmt::Debug for GrpcPDRouter { diff --git a/sgl-router/src/routers/grpc/pipeline.rs b/sgl-router/src/routers/grpc/pipeline.rs new file mode 100644 index 000000000..380be569e --- /dev/null +++ b/sgl-router/src/routers/grpc/pipeline.rs @@ -0,0 +1,1146 @@ +//! Pipeline stages for gRPC router request processing +//! +//! This module defines the core pipeline abstraction and individual processing stages +//! that transform a RequestContext through its lifecycle. 
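A minimal sketch (editor's illustration, not part of the patch) of how a list of stages could be driven over a RequestContext, assuming the trait's return type below is Result<Option<Response>, Response>:

    async fn run_stages(
        stages: &[Box<dyn PipelineStage>],
        ctx: &mut RequestContext,
    ) -> Result<Option<Response>, Response> {
        for stage in stages {
            // Ok(Some(response)) short-circuits the pipeline (e.g., a streaming response),
            // Err(response) aborts with an error response, Ok(None) continues to the next stage.
            if let Some(response) = stage.execute(ctx).await? {
                return Ok(Some(response));
            }
        }
        Ok(None)
    }
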
+ +use async_trait::async_trait; +use axum::response::{IntoResponse, Response}; +use tracing::{debug, error, warn}; + +use super::context::*; +use super::processing; +use super::streaming; +use super::utils; +use crate::core::{ConnectionMode, WorkerRegistry, WorkerType}; +use crate::grpc_client::proto; +use crate::policies::PolicyRegistry; +use crate::protocols::spec::{ + ChatCompletionRequest, ChatCompletionResponse, GenerateRequest, InputIds, Usage, +}; +use rand::Rng; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; +use uuid::Uuid; + +// ============================================================================ +// Pipeline Trait +// ============================================================================ + +/// Trait for pipeline stages that process requests +#[async_trait] +pub trait PipelineStage: Send + Sync { + /// Execute this stage, mutating the context + /// + /// Returns: + /// - `Ok(None)` - Continue to next stage + /// - `Ok(Some(response))` - Pipeline complete, return this response (e.g., streaming) + /// - `Err(response)` - Error occurred, return this error response + async fn execute(&self, ctx: &mut RequestContext) -> Result, Response>; + + /// Stage name for logging + fn name(&self) -> &'static str; +} + +// ============================================================================ +// Stage 1: Preparation +// ============================================================================ + +/// Preparation stage: Filter tools, process messages, tokenize, build constraints +pub struct PreparationStage; + +#[async_trait] +impl PipelineStage for PreparationStage { + async fn execute(&self, ctx: &mut RequestContext) -> Result, Response> { + debug!("Stage {}: Processing request", self.name()); + + // Clone the request to avoid borrowing issues + match &ctx.input.request_type { + RequestType::Chat(request) => { + let request_clone = request.clone(); + self.prepare_chat(ctx, &request_clone).await?; + } + RequestType::Generate(request) => { + let request_clone = request.clone(); + self.prepare_generate(ctx, &request_clone).await?; + } + } + + Ok(None) + } + + fn name(&self) -> &'static str { + "Preparation" + } +} + +impl PreparationStage { + async fn prepare_chat( + &self, + ctx: &mut RequestContext, + request: &ChatCompletionRequest, + ) -> Result<(), Response> { + // Step 1: Filter tools if needed + let body_ref = utils::filter_tools_for_request(request); + + // Step 2: Process messages and apply chat template + let processed_messages = + match utils::process_chat_messages(&body_ref, &*ctx.components.tokenizer) { + Ok(msgs) => msgs, + Err(e) => { + return Err(utils::bad_request_error(e)); + } + }; + + // Step 3: Tokenize the processed text + let encoding = match ctx.components.tokenizer.encode(&processed_messages.text) { + Ok(encoding) => encoding, + Err(e) => { + return Err(utils::internal_error_message(format!( + "Tokenization failed: {}", + e + ))); + } + }; + + let token_ids = encoding.token_ids().to_vec(); + debug!("Tokenized {} tokens from input", token_ids.len()); + + // Step 4: Build tool constraints if needed + let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| { + utils::generate_tool_constraints(tools, &request.tool_choice, &request.model) + }); + + // Step 5: Create stop sequence decoder (build once, reuse in non-stream) + let stop_decoder = utils::create_stop_decoder( + &ctx.components.tokenizer, + request.stop.as_ref(), + request.stop_token_ids.as_ref(), + request.skip_special_tokens, + request.no_stop_trim, + ); + + // 
Store results in context + ctx.state.preparation = Some(PreparationOutput { + original_text: Some(processed_messages.text.clone()), + token_ids, + processed_messages: Some(processed_messages), + tool_constraints: tool_call_constraint, + filtered_request: if matches!(body_ref, std::borrow::Cow::Owned(_)) { + Some(body_ref.into_owned()) + } else { + None + }, + }); + + // Store stop decoder for reuse in response processing + ctx.state.response.stop_decoder = Some(stop_decoder); + + Ok(()) + } + + async fn prepare_generate( + &self, + ctx: &mut RequestContext, + request: &GenerateRequest, + ) -> Result<(), Response> { + // Resolve input (text, prompt, or input_ids) + let (original_text, token_ids) = match self.resolve_generate_input(ctx, request) { + Ok(res) => res, + Err(msg) => { + return Err(utils::bad_request_error(msg)); + } + }; + + debug!("Resolved input with {} tokens", token_ids.len()); + + // Create stop sequence decoder for generate requests + let params = request.sampling_params.as_ref(); + let stop_decoder = utils::create_stop_decoder( + &ctx.components.tokenizer, + params.and_then(|p| p.stop.as_ref()), + params.and_then(|p| p.stop_token_ids.as_ref()), + params.and_then(|p| p.skip_special_tokens).unwrap_or(true), + params.and_then(|p| p.no_stop_trim).unwrap_or(false), + ); + + ctx.state.preparation = Some(PreparationOutput { + original_text, + token_ids, + processed_messages: None, + tool_constraints: None, + filtered_request: None, + }); + + // Store stop decoder + ctx.state.response.stop_decoder = Some(stop_decoder); + + Ok(()) + } + + fn resolve_generate_input( + &self, + ctx: &RequestContext, + request: &GenerateRequest, + ) -> Result<(Option, Vec), String> { + if let Some(text) = &request.text { + return self + .tokenize_single_text(&ctx.components.tokenizer, text) + .map(|(original, ids)| (Some(original), ids)); + } + + // Handle input_ids - validate and convert + if let Some(input_ids) = &request.input_ids { + return match input_ids { + InputIds::Single(ids) => ids + .iter() + .map(|&id| u32::try_from(id)) + .collect::, _>>() + .map(|converted| (None, converted)) + .map_err(|_| "input_ids must be non-negative".to_string()), + InputIds::Batch(_) => { + Err("Batch input_ids are not supported over gRPC generate yet".to_string()) + } + }; + } + + Err("Either `text` or `input_ids` must be provided".to_string()) + } + + fn tokenize_single_text( + &self, + tokenizer: &Arc, + text: &str, + ) -> Result<(String, Vec), String> { + let encoding = tokenizer + .encode(text) + .map_err(|e| format!("Tokenization failed: {}", e))?; + Ok((text.to_string(), encoding.token_ids().to_vec())) + } +} + +// ============================================================================ +// Stage 2: Worker Selection +// ============================================================================ + +/// Worker selection stage: Select appropriate worker(s) based on routing mode +pub struct WorkerSelectionStage { + worker_registry: Arc, + policy_registry: Arc, + mode: WorkerSelectionMode, +} + +pub enum WorkerSelectionMode { + /// Regular mode: select single worker + Regular, + /// PD mode: select prefill + decode workers + PrefillDecode, +} + +impl WorkerSelectionStage { + pub fn new( + worker_registry: Arc, + policy_registry: Arc, + mode: WorkerSelectionMode, + ) -> Self { + Self { + worker_registry, + policy_registry, + mode, + } + } +} + +#[async_trait] +impl PipelineStage for WorkerSelectionStage { + async fn execute(&self, ctx: &mut RequestContext) -> Result, Response> { + debug!("Stage {}: 
Selecting workers", self.name()); + + let prep = ctx + .state + .preparation + .as_ref() + .ok_or_else(|| utils::internal_error_static("Preparation stage not completed"))?; + + let text = prep.original_text.as_deref(); + + let workers = match self.mode { + WorkerSelectionMode::Regular => { + match self.select_single_worker(ctx.input.model_id.as_deref(), text) { + Some(w) => WorkerSelection::Single { worker: w }, + None => { + return Err(utils::service_unavailable_error(format!( + "No available workers for model: {:?}", + ctx.input.model_id + ))); + } + } + } + WorkerSelectionMode::PrefillDecode => { + match self.select_pd_pair(ctx.input.model_id.as_deref(), text) { + Some((prefill, decode)) => WorkerSelection::Dual { prefill, decode }, + None => { + return Err(utils::service_unavailable_error(format!( + "No available PD worker pairs for model: {:?}", + ctx.input.model_id + ))); + } + } + } + }; + + ctx.state.workers = Some(workers); + Ok(None) + } + + fn name(&self) -> &'static str { + "WorkerSelection" + } +} + +impl WorkerSelectionStage { + fn select_single_worker( + &self, + model_id: Option<&str>, + text: Option<&str>, + ) -> Option> { + // Get workers for the specified model, filtered by connection mode + let workers = self.worker_registry.get_workers_filtered( + model_id, + Some(WorkerType::Regular), + Some(ConnectionMode::Grpc { port: None }), + false, // get all workers, we'll filter by is_available() next + ); + + // Filter by availability (health + circuit breaker) + let available: Vec> = workers + .iter() + .filter(|w| w.is_available()) + .cloned() + .collect(); + + if available.is_empty() { + return None; + } + + // Get the appropriate policy for this model + let policy = match model_id { + Some(model) => self.policy_registry.get_policy_or_default(model), + None => self.policy_registry.get_default_policy(), + }; + + // Select worker using the policy + let idx = policy.select_worker(&available, text)?; + Some(available[idx].clone()) + } + + fn select_pd_pair( + &self, + model_id: Option<&str>, + text: Option<&str>, + ) -> Option<(Arc, Arc)> { + // Get prefill workers - use None for WorkerType filter to get all types, + // then filter manually (since Prefill is a struct variant) + let all_workers = self.worker_registry.get_workers_filtered( + model_id, + None, // Get all types + Some(ConnectionMode::Grpc { port: None }), + false, + ); + + let prefill_workers: Vec<_> = all_workers + .iter() + .filter(|w| matches!(w.metadata().worker_type, WorkerType::Prefill { .. 
})) + .cloned() + .collect(); + + let available_prefill: Vec<_> = prefill_workers + .iter() + .filter(|w| w.is_available()) + .cloned() + .collect(); + + if available_prefill.is_empty() { + warn!("No available prefill workers"); + return None; + } + + // Get decode workers from the same all_workers list + let decode_workers: Vec<_> = all_workers + .iter() + .filter(|w| matches!(w.metadata().worker_type, WorkerType::Decode)) + .cloned() + .collect(); + + let available_decode: Vec<_> = decode_workers + .iter() + .filter(|w| w.is_available()) + .cloned() + .collect(); + + if available_decode.is_empty() { + warn!("No available decode workers"); + return None; + } + + // Select using policies + let policy = match model_id { + Some(model) => self.policy_registry.get_policy_or_default(model), + None => self.policy_registry.get_default_policy(), + }; + + let prefill_idx = policy.select_worker(&available_prefill, text)?; + let decode_idx = policy.select_worker(&available_decode, text)?; + + Some(( + available_prefill[prefill_idx].clone(), + available_decode[decode_idx].clone(), + )) + } +} + +// ============================================================================ +// Stage 3: Client Acquisition +// ============================================================================ + +/// Client acquisition stage: Get gRPC clients from selected workers +pub struct ClientAcquisitionStage; + +#[async_trait] +impl PipelineStage for ClientAcquisitionStage { + async fn execute(&self, ctx: &mut RequestContext) -> Result, Response> { + debug!("Stage {}: Acquiring gRPC clients", self.name()); + + let workers = ctx + .state + .workers + .as_ref() + .ok_or_else(|| utils::internal_error_static("Worker selection not completed"))?; + + let clients = match workers { + WorkerSelection::Single { worker } => { + let client = utils::get_grpc_client_from_worker(worker).await?; + ClientSelection::Single { client } + } + WorkerSelection::Dual { prefill, decode } => { + let prefill_client = utils::get_grpc_client_from_worker(prefill).await?; + let decode_client = utils::get_grpc_client_from_worker(decode).await?; + ClientSelection::Dual { + prefill: prefill_client, + decode: decode_client, + } + } + }; + + ctx.state.clients = Some(clients); + Ok(None) + } + + fn name(&self) -> &'static str { + "ClientAcquisition" + } +} + +// ============================================================================ +// Stage 4: Request Building +// ============================================================================ + +/// Request building stage: Build proto GenerateRequest +pub struct RequestBuildingStage { + inject_pd_metadata: bool, +} + +impl RequestBuildingStage { + pub fn new(inject_pd_metadata: bool) -> Self { + Self { inject_pd_metadata } + } +} + +#[async_trait] +impl PipelineStage for RequestBuildingStage { + async fn execute(&self, ctx: &mut RequestContext) -> Result, Response> { + debug!("Stage {}: Building proto request", self.name()); + + let prep = ctx + .state + .preparation + .as_ref() + .ok_or_else(|| utils::internal_error_static("Preparation not completed"))?; + + let clients = ctx + .state + .clients + .as_ref() + .ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?; + + // Get client for building request (use prefill client if PD mode) + let builder_client = match clients { + ClientSelection::Single { client } => client, + ClientSelection::Dual { prefill, .. 
} => prefill, + }; + + let mut proto_request = match &ctx.input.request_type { + RequestType::Chat(request) => { + let request_id = format!("chatcmpl-{}", Uuid::new_v4()); + let body_ref = prep.filtered_request.as_ref().unwrap_or(request); + + builder_client + .build_generate_request( + request_id, + body_ref, + prep.processed_messages.as_ref().unwrap().text.clone(), + prep.token_ids.clone(), + prep.processed_messages + .as_ref() + .unwrap() + .multimodal_inputs + .clone(), + prep.tool_constraints.clone(), + ) + .map_err(|e| { + utils::bad_request_error(format!("Invalid request parameters: {}", e)) + })? + } + RequestType::Generate(request) => { + let request_id = request + .rid + .clone() + .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4())); + + builder_client + .build_plain_generate_request( + request_id, + request, + prep.original_text.clone(), + prep.token_ids.clone(), + ) + .map_err(utils::bad_request_error)? + } + }; + + // Inject PD metadata if needed + if self.inject_pd_metadata { + if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() { + self.inject_bootstrap_metadata(&mut proto_request, prefill); + } + } + + ctx.state.proto_request = Some(proto_request); + Ok(None) + } + + fn name(&self) -> &'static str { + "RequestBuilding" + } +} + +impl RequestBuildingStage { + fn inject_bootstrap_metadata( + &self, + request: &mut proto::GenerateRequest, + prefill_worker: &Arc, + ) { + use proto::DisaggregatedParams; + + let hostname = prefill_worker.bootstrap_host(); + let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998); + + // Generate room ID for bootstrap + let room_id = rand::rng().random_range(0..i32::MAX); + + // Create DisaggregatedParams + let disagg_params = DisaggregatedParams { + bootstrap_host: hostname.to_string(), + bootstrap_port: bootstrap_port as i32, + bootstrap_room: room_id, + }; + + // Inject metadata directly into request + request.disaggregated_params = Some(disagg_params); + + debug!( + "Injected bootstrap metadata: host={}, port={}, room={}", + hostname, bootstrap_port, room_id + ); + } +} + +// ============================================================================ +// Stage 5: Dispatch Metadata +// ============================================================================ + +/// Dispatch metadata stage: Prepare metadata for dispatch +pub struct DispatchMetadataStage; + +#[async_trait] +impl PipelineStage for DispatchMetadataStage { + async fn execute(&self, ctx: &mut RequestContext) -> Result, Response> { + debug!("Stage {}: Preparing dispatch metadata", self.name()); + + let proto_request = ctx + .state + .proto_request + .as_ref() + .ok_or_else(|| utils::internal_error_static("Proto request not built"))?; + + let request_id = proto_request.request_id.clone(); + let model = match &ctx.input.request_type { + RequestType::Chat(req) => req.model.clone(), + RequestType::Generate(_req) => { + // Generate requests don't have a model field + // Use model_id from input or default + ctx.input + .model_id + .clone() + .unwrap_or_else(|| "default".to_string()) + } + }; + + let weight_version = ctx + .state + .workers + .as_ref() + .map(|w| match w { + WorkerSelection::Single { worker } => worker, + WorkerSelection::Dual { decode, .. 
} => decode, + }) + .and_then(|w| w.metadata().labels.get("weight_version").cloned()) + .unwrap_or_else(|| "default".to_string()); + + let created = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + ctx.state.dispatch = Some(DispatchMetadata { + request_id, + model, + created, + weight_version: Some(weight_version), + is_streaming: ctx.is_streaming(), + }); + + Ok(None) + } + + fn name(&self) -> &'static str { + "DispatchMetadata" + } +} + +// ============================================================================ +// Stage 6: Request Execution +// ============================================================================ + +/// Request execution stage: Execute gRPC requests (single or dual dispatch) +pub struct RequestExecutionStage { + mode: ExecutionMode, +} + +pub enum ExecutionMode { + /// Regular mode: single worker execution + Single, + /// PD mode: dual dispatch to prefill + decode workers + DualDispatch, +} + +impl RequestExecutionStage { + pub fn new(mode: ExecutionMode) -> Self { + Self { mode } + } +} + +#[async_trait] +impl PipelineStage for RequestExecutionStage { + async fn execute(&self, ctx: &mut RequestContext) -> Result, Response> { + debug!("Stage {}: Executing gRPC request", self.name()); + + let proto_request = ctx + .state + .proto_request + .take() + .ok_or_else(|| utils::internal_error_static("Proto request not built"))?; + + let clients = ctx + .state + .clients + .as_mut() + .ok_or_else(|| utils::internal_error_static("Client acquisition not completed"))?; + + let result = match self.mode { + ExecutionMode::Single => self.execute_single(proto_request, clients).await?, + ExecutionMode::DualDispatch => { + self.execute_dual_dispatch(proto_request, clients).await? + } + }; + + // Store result in context for ResponseProcessingStage + ctx.state.response.execution_result = Some(result); + Ok(None) + } + + fn name(&self) -> &'static str { + "RequestExecution" + } +} + +impl RequestExecutionStage { + async fn execute_single( + &self, + proto_request: proto::GenerateRequest, + clients: &mut ClientSelection, + ) -> Result { + let client = clients + .single_mut() + .ok_or_else(|| utils::internal_error_static("Expected single client but got dual"))?; + + let stream = client.generate(proto_request).await.map_err(|e| { + utils::internal_error_message(format!("Failed to start generation: {}", e)) + })?; + + Ok(ExecutionResult::Single { stream }) + } + + async fn execute_dual_dispatch( + &self, + proto_request: proto::GenerateRequest, + clients: &mut ClientSelection, + ) -> Result { + let (prefill_client, decode_client) = clients + .dual_mut() + .ok_or_else(|| utils::internal_error_static("Expected dual clients but got single"))?; + + debug!("Sending concurrent requests to prefill and decode workers"); + + let prefill_request = proto_request.clone(); + let decode_request = proto_request; + + let (prefill_result, decode_result) = tokio::join!( + prefill_client.generate(prefill_request), + decode_client.generate(decode_request) + ); + + // Handle prefill result + let prefill_stream = match prefill_result { + Ok(s) => s, + Err(e) => { + return Err(utils::internal_error_message(format!( + "Prefill worker failed to start: {}", + e + ))); + } + }; + + // Handle decode result + let decode_stream = match decode_result { + Ok(s) => s, + Err(e) => { + return Err(utils::internal_error_message(format!( + "Decode worker failed to start: {}", + e + ))); + } + }; + + Ok(ExecutionResult::Dual { + prefill: prefill_stream, + decode: 
Box::new(decode_stream), + }) + } +} + +// ============================================================================ +// Stage 7: Response Processing +// ============================================================================ + +/// Response processing stage: Handles both streaming and non-streaming responses +/// +/// - For streaming: Spawns background task and returns SSE response (early exit) +/// - For non-streaming: Collects all responses and builds final ChatCompletionResponse +pub struct ResponseProcessingStage { + processor: processing::ResponseProcessor, + streaming_processor: Arc, +} + +impl ResponseProcessingStage { + pub fn new( + processor: processing::ResponseProcessor, + streaming_processor: Arc, + ) -> Self { + Self { + processor, + streaming_processor, + } + } +} + +#[async_trait] +impl PipelineStage for ResponseProcessingStage { + async fn execute(&self, ctx: &mut RequestContext) -> Result, Response> { + debug!("Stage {}: Processing response", self.name()); + + // Delegate to request-type specific processing + match &ctx.input.request_type { + RequestType::Chat(_) => return self.process_chat_response(ctx).await, + RequestType::Generate(_) => return self.process_generate_response(ctx).await, + } + } + + fn name(&self) -> &'static str { + "ResponseProcessing" + } +} + +impl ResponseProcessingStage { + async fn process_chat_response( + &self, + ctx: &mut RequestContext, + ) -> Result, Response> { + let is_streaming = ctx.is_streaming(); + + // Extract execution result + let execution_result = ctx + .state + .response + .execution_result + .take() + .ok_or_else(|| utils::internal_error_static("No execution result"))?; + + if is_streaming { + // Get dispatch metadata for consistent response fields + let dispatch = ctx + .state + .dispatch + .as_ref() + .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?; + + // Streaming: Use StreamingProcessor and return SSE response (done) + return Ok(Some( + self.streaming_processor.clone().process_streaming_response( + execution_result, + ctx.chat_request().clone(), + dispatch.clone(), + ), + )); + } + + // Non-streaming: Extract chat request details before mutable borrows + let request_logprobs = match &ctx.input.request_type { + RequestType::Chat(req) => req.logprobs, + _ => false, + }; + + // Collect all responses from the execution result + let all_responses = match execution_result { + ExecutionResult::Single { stream } => { + utils::collect_stream_responses(stream, "Single").await? 
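+                    // Note (descriptive comment, not from the original patch):
+                    // collect_stream_responses drains the gRPC stream and returns the final
+                    // GenerateComplete message(s); the loop further below builds one
+                    // ChatChoice per returned entry (one per choice when n > 1).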
+ } + ExecutionResult::Dual { prefill, decode } => { + // Collect prefill for input_logprobs + let prefill_responses = utils::collect_stream_responses(prefill, "Prefill").await?; + + // Collect decode for actual output + let mut decode_responses = + utils::collect_stream_responses(*decode, "Decode").await?; + + // Merge prefill input_logprobs if requested + if request_logprobs { + if let Some(prefill_input_logprobs) = prefill_responses + .first() + .and_then(|r| r.input_logprobs.clone()) + { + for response in &mut decode_responses { + response.input_logprobs = Some(prefill_input_logprobs.clone()); + } + } + } + + decode_responses + } + }; + + if all_responses.is_empty() { + return Err(utils::internal_error_static("No responses from server")); + } + + // Clone chat_request to avoid borrow checker conflict + // (ctx.chat_request() borrows ctx, preventing mutable borrow of ctx.state.response.stop_decoder) + let chat_request = ctx.chat_request().clone(); + let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request); + + let stop_decoder = ctx + .state + .response + .stop_decoder + .as_mut() + .ok_or_else(|| utils::internal_error_static("Stop decoder not initialized"))?; + + let mut choices = Vec::new(); + for (index, complete) in all_responses.iter().enumerate() { + match self + .processor + .process_single_choice( + complete, + index, + &chat_request, + stop_decoder, + history_tool_calls_count, + ) + .await + { + Ok(choice) => choices.push(choice), + Err(e) => { + return Err(utils::internal_error_message(format!( + "Failed to process choice {}: {}", + index, e + ))); + } + } + } + + // Build usage + let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum(); + let total_completion_tokens: u32 = all_responses + .iter() + .map(|r| r.completion_tokens as u32) + .sum(); + let usage = Usage { + prompt_tokens: total_prompt_tokens, + completion_tokens: total_completion_tokens, + total_tokens: total_prompt_tokens + total_completion_tokens, + completion_tokens_details: None, + }; + + // Build final ChatCompletionResponse + let dispatch = ctx + .state + .dispatch + .as_ref() + .ok_or_else(|| utils::internal_error_static("Dispatch metadata not set"))?; + + let response = ChatCompletionResponse { + id: dispatch.request_id.clone(), + object: "chat.completion".to_string(), + created: dispatch.created, + model: dispatch.model.clone(), + choices, + usage: Some(usage), + system_fingerprint: dispatch.weight_version.clone(), + }; + + // Store the final response + ctx.state.response.final_response = Some(FinalResponse::Chat(response)); + + Ok(None) + } + + async fn process_generate_response( + &self, + _ctx: &mut RequestContext, + ) -> Result, Response> { + // TODO(generate): Implement generate response processing + // + // Required implementation: + // 1. Extract execution_result from ctx + // 2. Check is_streaming flag + // 3. For streaming: + // - Add StreamingProcessor::process_streaming_generate() method + // - Similar to process_streaming_response but WITHOUT tool/reasoning parsing + // - Return Err(sse_response) for early exit + // 4. 
For non-streaming: + // - Collect stream responses using utils::collect_stream_responses() + // - Process through stop decoder (sequential with reset for n>1, like chat) + // - Build GenerateResponse struct (see TODO in protocols/spec.rs) + // - Set ctx.state.response.final_response = Some(FinalResponse::Generate(response)) + // + // Reference implementation: router.rs:297-595 + // Key differences from chat: + // - No tool parsing + // - No reasoning parsing + // - Different response format (GenerateResponse instead of ChatCompletionResponse) + // - Still needs: stop decoder, logprobs, finish_reason, matched_stop + Err(( + axum::http::StatusCode::NOT_IMPLEMENTED, + axum::Json(serde_json::json!({ + "error": { + "message": "Generate response processing not yet implemented in pipeline", + "type": "not_implemented", + "code": 501 + } + })), + ) + .into_response()) + } +} + +// ============================================================================ +// Pipeline Orchestrator +// ============================================================================ + +/// Complete chat completion pipeline +/// +/// Orchestrates all stages from request preparation to response delivery. +/// Configured differently for regular vs PD mode. +#[derive(Clone)] +pub struct ChatCompletionPipeline { + stages: Arc>>, +} + +impl ChatCompletionPipeline { + /// Create a regular (single-worker) pipeline + pub fn new_regular( + worker_registry: Arc, + policy_registry: Arc, + processor: processing::ResponseProcessor, + streaming_processor: Arc, + ) -> Self { + let stages: Vec> = vec![ + Box::new(PreparationStage), + Box::new(WorkerSelectionStage::new( + worker_registry, + policy_registry, + WorkerSelectionMode::Regular, + )), + Box::new(ClientAcquisitionStage), + Box::new(RequestBuildingStage::new(false)), // No PD metadata + Box::new(DispatchMetadataStage), + Box::new(RequestExecutionStage::new(ExecutionMode::Single)), + Box::new(ResponseProcessingStage::new( + processor, + streaming_processor.clone(), + )), + ]; + + Self { + stages: Arc::new(stages), + } + } + + /// Create a PD (prefill-decode) pipeline + pub fn new_pd( + worker_registry: Arc, + policy_registry: Arc, + processor: processing::ResponseProcessor, + streaming_processor: Arc, + ) -> Self { + let stages: Vec> = vec![ + Box::new(PreparationStage), + Box::new(WorkerSelectionStage::new( + worker_registry, + policy_registry, + WorkerSelectionMode::PrefillDecode, + )), + Box::new(ClientAcquisitionStage), + Box::new(RequestBuildingStage::new(true)), // Inject PD metadata + Box::new(DispatchMetadataStage), + Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), + Box::new(ResponseProcessingStage::new( + processor, + streaming_processor.clone(), + )), + ]; + + Self { + stages: Arc::new(stages), + } + } + + /// Execute the complete pipeline for a chat request + pub async fn execute_chat( + &self, + request: ChatCompletionRequest, + headers: Option, + model_id: Option, + components: Arc, + ) -> Response { + let mut ctx = RequestContext::for_chat(request, headers, model_id, components); + + // Execute each stage in sequence + for (idx, stage) in self.stages.iter().enumerate() { + debug!("Executing stage {}: {}", idx + 1, stage.name()); + match stage.execute(&mut ctx).await { + Ok(Some(response)) => { + // Stage completed successfully with a response (e.g., streaming) + debug!( + "Stage {} ({}) completed with response", + idx + 1, + stage.name() + ); + return response; + } + Ok(None) => { + // Continue to next stage + continue; + } + Err(response) => { + 
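+                    // The failing stage has already built a complete HTTP error response
+                    // (e.g. via the utils::bad_request_error / service_unavailable_error
+                    // helpers), so the pipeline short-circuits and returns it as-is.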
// Error occurred + error!( + "Stage {} ({}) failed with status {}", + idx + 1, + stage.name(), + response.status() + ); + return response; + } + } + } + + // Extract final response + match ctx.state.response.final_response { + Some(FinalResponse::Chat(response)) => axum::Json(response).into_response(), + Some(FinalResponse::Generate(_)) => { + utils::internal_error_static("Internal error: wrong response type") + } + None => utils::internal_error_static("No response produced"), + } + } + + /// Execute the complete pipeline for a generate request + pub async fn execute_generate( + &self, + request: GenerateRequest, + headers: Option, + model_id: Option, + components: Arc, + ) -> Response { + let mut ctx = RequestContext::for_generate(request, headers, model_id, components); + + // Execute each stage in sequence + for (idx, stage) in self.stages.iter().enumerate() { + debug!("Executing stage {}: {}", idx + 1, stage.name()); + match stage.execute(&mut ctx).await { + Ok(Some(response)) => { + // Stage completed successfully with a response (e.g., streaming) + debug!( + "Stage {} ({}) completed with response", + idx + 1, + stage.name() + ); + return response; + } + Ok(None) => { + // Continue to next stage + continue; + } + Err(response) => { + // Error occurred + error!( + "Stage {} ({}) failed with status {}", + idx + 1, + stage.name(), + response.status() + ); + return response; + } + } + } + + // Extract final response + match ctx.state.response.final_response { + Some(FinalResponse::Generate(response)) => axum::Json(*response).into_response(), + Some(FinalResponse::Chat(_)) => { + utils::internal_error_static("Internal error: wrong response type") + } + None => utils::internal_error_static("No response produced"), + } + } +} diff --git a/sgl-router/src/routers/grpc/processing.rs b/sgl-router/src/routers/grpc/processing.rs new file mode 100644 index 000000000..7451236fb --- /dev/null +++ b/sgl-router/src/routers/grpc/processing.rs @@ -0,0 +1,268 @@ +//! Shared response processing logic for gRPC routers +//! +//! This module contains response processing functions that are shared between +//! the regular router and PD router, eliminating ~1,200 lines of exact duplicates. 
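+//!
+//! Rough usage sketch (illustrative only; the real wiring lives in `router.rs` and
+//! `pipeline.rs`, and the variable names below are assumed, not part of this module):
+//!
+//! ```ignore
+//! // Build one processor from the router's shared components.
+//! let processor = ResponseProcessor::new(
+//!     tokenizer.clone(),
+//!     tool_parser_factory.clone(),
+//!     reasoning_parser_factory.clone(),
+//!     None, // configured tool parser override
+//!     None, // configured reasoning parser override
+//! );
+//!
+//! // Convert each collected GenerateComplete into an OpenAI-style ChatChoice.
+//! let choice = processor
+//!     .process_single_choice(&complete, 0, &chat_request, &mut stop_decoder, 0)
+//!     .await?;
+//! ```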
+ +use std::sync::Arc; + +use serde_json::Value; +use tracing::error; + +use crate::grpc_client::proto; +use crate::protocols::spec::{ + ChatChoice, ChatCompletionMessage, ChatCompletionRequest, FunctionCallResponse, ToolCall, + ToolChoice, ToolChoiceValue, +}; +use crate::reasoning_parser::ReasoningParserFactory; +use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; +use crate::tokenizer::traits::Tokenizer; +use crate::tool_parser::ToolParserFactory; + +use super::utils; + +// ============================================================================ +// Response Processor - Main Entry Point +// ============================================================================ + +/// Unified response processor for both routers +#[derive(Clone)] +pub struct ResponseProcessor { + pub tokenizer: Arc, + pub tool_parser_factory: ToolParserFactory, + pub reasoning_parser_factory: ReasoningParserFactory, + configured_tool_parser: Option, + configured_reasoning_parser: Option, +} + +impl ResponseProcessor { + pub fn new( + tokenizer: Arc, + tool_parser_factory: ToolParserFactory, + reasoning_parser_factory: ReasoningParserFactory, + configured_tool_parser: Option, + configured_reasoning_parser: Option, + ) -> Self { + Self { + tokenizer, + tool_parser_factory, + reasoning_parser_factory, + configured_tool_parser, + configured_reasoning_parser, + } + } + + /// Process a single choice from GenerateComplete response (EXACT COPY from router.rs:1573-1725) + pub async fn process_single_choice( + &self, + complete: &proto::GenerateComplete, + index: usize, + original_request: &ChatCompletionRequest, + stop_decoder: &mut StopSequenceDecoder, + history_tool_calls_count: usize, + ) -> Result { + stop_decoder.reset(); + // Decode tokens + let outputs = stop_decoder + .process_tokens(&complete.output_ids) + .map_err(|e| format!("Failed to process tokens: {}", e))?; + + // Accumulate text with early breaks + let mut final_text = String::new(); + for output in outputs { + match output { + SequenceDecoderOutput::Text(t) => final_text.push_str(&t), + SequenceDecoderOutput::StoppedWithText(t) => { + final_text.push_str(&t); + break; + } + SequenceDecoderOutput::Stopped => break, + SequenceDecoderOutput::Held => {} + } + } + + // Flush remaining text + if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { + final_text.push_str(&t); + } + + // Step 1: Handle reasoning content parsing + let mut reasoning_text: Option = None; + let mut processed_text = final_text; + + // Check if reasoning parsing is enabled and separate_reasoning is requested + if original_request.separate_reasoning { + let pooled_parser = utils::get_reasoning_parser( + &self.reasoning_parser_factory, + self.configured_reasoning_parser.as_ref(), + &original_request.model, + ); + + let mut parser = pooled_parser + .lock() + .map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?; + match parser.detect_and_parse_reasoning(&processed_text) { + Ok(result) => { + if !result.reasoning_text.is_empty() { + reasoning_text = Some(result.reasoning_text); + } + processed_text = result.normal_text; + } + Err(e) => { + return Err(format!("Reasoning parsing error: {}", e)); + } + } + } + + // Step 2: Handle tool call parsing + let mut tool_calls: Option> = None; + + // Check if tool calls should be processed + let tool_choice_enabled = !matches!( + &original_request.tool_choice, + Some(ToolChoice::Value(ToolChoiceValue::None)) + ); + + if tool_choice_enabled && original_request.tools.is_some() { + // Check if JSON schema 
constraint was used (specific function or required mode) + let used_json_schema = match &original_request.tool_choice { + Some(ToolChoice::Function { .. }) => true, + Some(ToolChoice::Value(ToolChoiceValue::Required)) => true, + Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required", + _ => false, + }; + + if used_json_schema { + (tool_calls, processed_text) = utils::parse_json_schema_response( + &processed_text, + &original_request.tool_choice, + ); + } else { + (tool_calls, processed_text) = self + .parse_tool_calls( + &processed_text, + &original_request.model, + history_tool_calls_count, + ) + .await; + } + } + + // Step 3: Use finish reason directly from proto (already OpenAI-compatible string) + let finish_reason_str = &complete.finish_reason; + + // Override finish reason if we have tool calls + let final_finish_reason_str = if tool_calls.is_some() { + "tool_calls" + } else { + finish_reason_str + }; + + // Extract matched_stop information from proto + let matched_stop = match &complete.matched_stop { + Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { + Some(Value::Number(serde_json::Number::from(*token_id))) + } + Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { + Some(Value::String(stop_str.clone())) + } + None => None, + }; + + // Step 4: Convert output logprobs if present + let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs { + match utils::convert_proto_to_openai_logprobs(proto_logprobs, &self.tokenizer) { + Ok(logprobs) => Some(logprobs), + Err(e) => { + error!("Failed to convert logprobs: {}", e); + None + } + } + } else { + None + }; + + // Step 5: Build ChatCompletionMessage (proper response message type) + let chat_message = ChatCompletionMessage { + role: "assistant".to_string(), + content: if processed_text.is_empty() { + None + } else { + Some(processed_text) + }, + tool_calls, + reasoning_content: reasoning_text, + }; + + // Step 6: Build ChatChoice + let choice = ChatChoice { + index: index as u32, + message: chat_message, + logprobs, + finish_reason: Some(final_finish_reason_str.to_string()), + matched_stop, + hidden_states: None, + }; + + Ok(choice) + } + + /// Parse tool calls using model-specific parser (EXACT COPY from router.rs:296-361) + pub async fn parse_tool_calls( + &self, + processed_text: &str, + model: &str, + history_tool_calls_count: usize, + ) -> (Option>, String) { + // Get pooled parser for this model + let pooled_parser = utils::get_tool_parser( + &self.tool_parser_factory, + self.configured_tool_parser.as_ref(), + model, + ); + + // Try parsing directly (parser will handle detection internally) + let result = { + let parser = pooled_parser.lock().await; + parser.parse_complete(processed_text).await + // Lock is dropped here + }; + + match result { + Ok((normal_text, parsed_tool_calls)) => { + if parsed_tool_calls.is_empty() { + return (None, normal_text); + } + + let spec_tool_calls = parsed_tool_calls + .into_iter() + .enumerate() + .map(|(index, tc)| { + // Generate ID for this tool call + let id = utils::generate_tool_call_id( + model, + &tc.function.name, + index, + history_tool_calls_count, + ); + ToolCall { + id, + tool_type: "function".to_string(), + function: FunctionCallResponse { + name: tc.function.name, + arguments: Some( + serde_json::to_string(&tc.function.arguments) + .unwrap_or_else(|_| "{}".to_string()), + ), + }, + } + }) + .collect(); + (Some(spec_tool_calls), normal_text) + } + Err(e) => { + error!("Tool call parsing error: {}", e); + (None, 
processed_text.to_string()) + } + } + } +} diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 8dc2316c9..6ccd18da6 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -1,44 +1,34 @@ // gRPC Router Implementation -use std::collections::HashMap; use std::sync::Arc; use async_trait::async_trait; use axum::{ body::Body, extract::Request, - http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, + http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, Json, }; -use bytes::Bytes; -use std::io; -use tokio::sync::mpsc; -use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{debug, error, warn}; +use tracing::debug; use crate::config::types::RetryConfig; use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::policies::PolicyRegistry; use crate::protocols::spec::{ - ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionStreamResponse, ChatMessage, ChatMessageDelta, ChatStreamChoice, - CompletionRequest, EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest, - RerankRequest, ResponsesGetParams, ResponsesRequest, StringOrArray, ToolCall, ToolCallDelta, - ToolChoice, ToolChoiceValue, Usage, + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, InputIds, + RerankRequest, ResponsesGetParams, ResponsesRequest, }; -use crate::reasoning_parser::{ParserResult, ReasoningParserFactory}; +use crate::reasoning_parser::ReasoningParserFactory; use crate::routers::{grpc, RouterTrait}; use crate::server::AppContext; -use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; +use crate::tokenizer::stop::SequenceDecoderOutput; use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::{StreamingParseResult, ToolParserFactory}; +use crate::tool_parser::ToolParserFactory; use grpc::utils; -use proto::generate_response::Response::{Chunk, Complete, Error}; -use serde_json::{json, Value}; -use std::time::{Instant, SystemTime, UNIX_EPOCH}; -use tokio_stream::StreamExt; +use serde_json::json; +use std::time::Instant; use uuid::Uuid; /// gRPC router implementation for SGLang @@ -55,6 +45,10 @@ pub struct GrpcRouter { retry_config: RetryConfig, configured_reasoning_parser: Option, configured_tool_parser: Option, + // Pipeline for non-streaming requests + pipeline: super::pipeline::ChatCompletionPipeline, + // Shared components for pipeline + shared_components: Arc, } impl GrpcRouter { @@ -80,6 +74,39 @@ impl GrpcRouter { let worker_registry = ctx.worker_registry.clone(); let policy_registry = ctx.policy_registry.clone(); + // Create shared components for pipeline + let shared_components = Arc::new(super::context::SharedComponents { + tokenizer: tokenizer.clone(), + tool_parser_factory: tool_parser_factory.clone(), + reasoning_parser_factory: reasoning_parser_factory.clone(), + }); + + // Create response processor + let processor = super::processing::ResponseProcessor::new( + tokenizer.clone(), + tool_parser_factory.clone(), + reasoning_parser_factory.clone(), + ctx.configured_tool_parser.clone(), + ctx.configured_reasoning_parser.clone(), + ); + + // Create streaming processor + let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new( + tokenizer.clone(), + tool_parser_factory.clone(), + reasoning_parser_factory.clone(), + ctx.configured_tool_parser.clone(), + 
ctx.configured_reasoning_parser.clone(), + )); + + // Create pipeline + let pipeline = super::pipeline::ChatCompletionPipeline::new_regular( + worker_registry.clone(), + policy_registry.clone(), + processor, + streaming_processor, + ); + Ok(GrpcRouter { worker_registry, policy_registry, @@ -91,13 +118,15 @@ impl GrpcRouter { retry_config: ctx.router_config.effective_retry_config(), configured_reasoning_parser: ctx.configured_reasoning_parser.clone(), configured_tool_parser: ctx.configured_tool_parser.clone(), + pipeline, + shared_components, }) } /// Main route_chat implementation async fn route_chat_impl( &self, - _headers: Option<&HeaderMap>, + headers: Option<&HeaderMap>, body: &ChatCompletionRequest, model_id: Option<&str>, ) -> Response { @@ -106,76 +135,15 @@ impl GrpcRouter { model_id ); - // Step 1: Filter tools if needed for allowed_tools or specific function - let body_ref = utils::filter_tools_for_request(body); - - // Step 2: Process messages and apply chat template - let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { - Ok(msgs) => msgs, - Err(e) => { - return utils::bad_request_error(e.to_string()); - } - }; - - // Step 3: Tokenize the processed text - let encoding = match self.tokenizer.encode(&processed_messages.text) { - Ok(encoding) => encoding, - Err(e) => { - return utils::internal_error_message(format!("Tokenization failed: {}", e)); - } - }; - - let token_ids = encoding.token_ids().to_vec(); - debug!("Tokenized {} tokens from input", token_ids.len()); - - // Step 4: Build tool constraints if needed - // body_ref already has filtered tools if needed - let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| { - utils::generate_tool_constraints(tools, &body.tool_choice, &body.model) - }); - - // Step 5: Select worker - let worker = match self.select_worker_for_request(model_id, Some(&processed_messages.text)) - { - Some(w) => w, - None => { - return utils::service_unavailable_error(format!( - "No available workers for model: {:?}", - model_id - )); - } - }; - - debug!("Selected worker: {}", worker.url()); - - // Step 6: Get gRPC client from worker - let client = match utils::get_grpc_client_from_worker(&worker).await { - Ok(client) => client, - Err(response) => return response, - }; - - // Step 7: Build the base gRPC request (use body_ref with filtered tools if applicable) - let request_id = format!("chatcmpl-{}", Uuid::new_v4()); - let request = match client.build_generate_request( - request_id, - &body_ref, - processed_messages.text.clone(), - token_ids, - processed_messages.multimodal_inputs, - tool_call_constraint, // Pass the full tuple (type, value) - ) { - Ok(request) => request, - Err(e) => { - return utils::bad_request_error(format!("Invalid request parameters: {}", e)); - } - }; - - // Step 7: Handle streaming vs non-streaming - if body.stream { - self.handle_streaming_chat(client, request, body).await - } else { - self.handle_non_streaming_chat(client, request, body).await - } + // Use pipeline for ALL requests (streaming and non-streaming) + self.pipeline + .execute_chat( + body.clone(), + headers.cloned(), + model_id.map(|s| s.to_string()), + self.shared_components.clone(), + ) + .await } /// Main route_generate implementation @@ -288,77 +256,6 @@ impl GrpcRouter { Some(available[idx].clone()) } - /// Parse tool calls using model-specific parser - async fn parse_tool_calls( - &self, - processed_text: &str, - model: &str, - history_tool_calls_count: usize, - ) -> (Option>, String) { - // Get pooled parser for 
this model - let pooled_parser = utils::get_tool_parser( - &self.tool_parser_factory, - self.configured_tool_parser.as_ref(), - model, - ); - - // Check format detection first - let can_parse = { - let parser = pooled_parser.lock().await; - parser.has_tool_markers(processed_text) - // Lock is dropped here - }; - - if !can_parse { - return (None, processed_text.to_string()); - } - - // Lock again for async parsing - let result = { - let parser = pooled_parser.lock().await; - parser.parse_complete(processed_text).await - // Lock is dropped here - }; - - match result { - Ok((normal_text, parsed_tool_calls)) => { - if parsed_tool_calls.is_empty() { - return (None, normal_text); - } - - let spec_tool_calls = parsed_tool_calls - .into_iter() - .enumerate() - .map(|(index, tc)| { - // Generate ID for this tool call - let id = Self::generate_tool_call_id( - model, - &tc.function.name, - index, - history_tool_calls_count, - ); - ToolCall { - id, - tool_type: "function".to_string(), - function: FunctionCallResponse { - name: tc.function.name, - arguments: Some( - serde_json::to_string(&tc.function.arguments) - .unwrap_or_else(|_| "{}".to_string()), - ), - }, - } - }) - .collect(); - (Some(spec_tool_calls), normal_text) - } - Err(e) => { - error!("Tool call parsing error: {}", e); - (None, processed_text.to_string()) - } - } - } - /// Resolve the generate input into optional original text and token IDs fn resolve_generate_input( &self, @@ -373,13 +270,13 @@ impl GrpcRouter { // Handle input_ids - validate and convert if let Some(input_ids) = &request.input_ids { return match input_ids { - crate::protocols::spec::InputIds::Single(ids) => ids + InputIds::Single(ids) => ids .iter() .map(|&id| u32::try_from(id)) .collect::, _>>() .map(|converted| (None, converted)) .map_err(|_| "input_ids must be non-negative".to_string()), - crate::protocols::spec::InputIds::Batch(_) => { + InputIds::Batch(_) => { Err("Batch input_ids are not supported over gRPC generate yet".to_string()) } }; @@ -396,837 +293,6 @@ impl GrpcRouter { Ok((text.to_string(), encoding.token_ids().to_vec())) } - /// Count the number of tool calls in the request message history - /// This is used for KimiK2 format which needs globally unique indices - fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize { - request - .messages - .iter() - .filter_map(|msg| { - if let ChatMessage::Assistant { tool_calls, .. } = msg { - tool_calls.as_ref().map(|calls| calls.len()) - } else { - None - } - }) - .sum() - } - - /// Generate a tool call ID based on model format - /// - /// # Arguments - /// * `model` - Model name to determine ID format - /// * `tool_name` - Name of the tool being called - /// * `tool_index` - Index of this tool call within the current message - /// * `history_count` - Number of tool calls in previous messages - /// - /// # Returns - /// A unique ID string. 
KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}` - fn generate_tool_call_id( - model: &str, - tool_name: &str, - tool_index: usize, - history_count: usize, - ) -> String { - if model.to_lowercase().contains("kimi") { - // KimiK2 format: functions.{name}:{global_index} - format!("functions.{}:{}", tool_name, history_count + tool_index) - } else { - // Standard OpenAI format: call_{24-char-uuid} - format!("call_{}", &Uuid::new_v4().simple().to_string()[..24]) - } - } - - /// Process a chunk of tokens through the stop decoder - fn process_chunk_tokens( - stop_decoder: &mut StopSequenceDecoder, - token_ids: &[u32], - ) -> (String, bool) { - let mut chunk_text = String::new(); - - for &token_id in token_ids { - match stop_decoder.process_token(token_id).unwrap_or_else(|e| { - debug!( - "Error processing token {}: {}. Treating as Held.", - token_id, e - ); - SequenceDecoderOutput::Held - }) { - SequenceDecoderOutput::Text(text) => { - chunk_text.push_str(&text); - } - SequenceDecoderOutput::StoppedWithText(text) => { - chunk_text.push_str(&text); - return (chunk_text, true); // Return text and signal to stop - } - SequenceDecoderOutput::Stopped => { - return (chunk_text, true); // Return text and signal to stop - } - SequenceDecoderOutput::Held => { - // Text held for potential stop sequence match - } - } - } - (chunk_text, false) // Return text and continue processing - } - - /// Helper: Process reasoning content in streaming mode - /// Returns (modified_delta, optional_reasoning_chunk) - fn process_reasoning_stream( - &self, - delta: &str, - index: u32, - reasoning_parsers: &mut HashMap< - u32, - Arc>>, - >, - request_id: &str, - model: &str, - created: u64, - ) -> (String, Option, bool) { - // Get or create parser for this index - reasoning_parsers.entry(index).or_insert_with(|| { - utils::get_reasoning_parser( - &self.reasoning_parser_factory, - self.configured_reasoning_parser.as_ref(), - model, - ) - }); - - if let Some(pooled_parser) = reasoning_parsers.get(&index) { - let (parse_result, in_reasoning) = { - let mut parser = pooled_parser.lock().unwrap(); - let result = parser.parse_reasoning_streaming_incremental(delta); - let in_reasoning = parser.is_in_reasoning(); - (result, in_reasoning) - }; - - match parse_result { - Ok(ParserResult { - reasoning_text, - normal_text, - }) => { - let chunk = if !reasoning_text.is_empty() { - Some(ChatCompletionStreamResponse { - id: request_id.to_string(), - object: "chat.completion.chunk".to_string(), - created, - model: model.to_string(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: None, - reasoning_content: Some(reasoning_text), - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }) - } else { - None - }; - return (normal_text, chunk, in_reasoning); - } - Err(e) => { - warn!("Reasoning parsing error: {}", e); - } - } - } - - (delta.to_string(), None, false) - } - - /// Helper: Process tool calls in streaming mode - /// Returns (should_skip_content, chunks_to_emit) - #[allow(clippy::too_many_arguments)] - async fn process_tool_calls_stream( - &self, - delta: &str, - index: u32, - tool_parsers: &mut HashMap< - u32, - Arc>>, - >, - has_tool_calls: &mut HashMap, - tools: &[crate::protocols::spec::Tool], - request_id: &str, - model: &str, - created: u64, - history_tool_calls_count: usize, - ) -> (bool, Vec) { - let mut chunks = Vec::new(); - - // Get or create parser 
for this index - tool_parsers.entry(index).or_insert_with(|| { - utils::get_tool_parser( - &self.tool_parser_factory, - self.configured_tool_parser.as_ref(), - model, - ) - }); - - if let Some(pooled_parser) = tool_parsers.get(&index) { - let mut parser = pooled_parser.lock().await; - match parser.parse_incremental(delta, tools).await { - Ok(StreamingParseResult { normal_text, calls }) => { - // Emit normal text if present - if !normal_text.is_empty() { - chunks.push(ChatCompletionStreamResponse { - id: request_id.to_string(), - object: "chat.completion.chunk".to_string(), - created, - model: model.to_string(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: Some(normal_text), - tool_calls: None, - reasoning_content: None, - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }); - } - - // Emit tool call chunks - for tool_call_item in calls { - has_tool_calls.insert(index, true); - - let tool_call_id = if let Some(ref name) = tool_call_item.name { - Some(Self::generate_tool_call_id( - model, - name, - tool_call_item.tool_index, - history_tool_calls_count, - )) - } else { - None - }; - - let tool_call_delta = ToolCallDelta { - index: tool_call_item.tool_index as u32, - id: tool_call_id, - tool_type: if tool_call_item.name.is_some() { - Some("function".to_string()) - } else { - None - }, - function: Some(FunctionCallDelta { - name: tool_call_item.name, - arguments: if !tool_call_item.parameters.is_empty() { - Some(tool_call_item.parameters) - } else { - None - }, - }), - }; - - chunks.push(ChatCompletionStreamResponse { - id: request_id.to_string(), - object: "chat.completion.chunk".to_string(), - created, - model: model.to_string(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: Some(vec![tool_call_delta]), - reasoning_content: None, - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }); - } - - // If we emitted chunks, skip regular content - return (!chunks.is_empty(), chunks); - } - Err(e) => { - warn!("Tool call parsing error: {}", e); - } - } - } - - (false, chunks) - } - - /// Helper: Create content chunk - fn create_content_chunk( - content: String, - index: u32, - request_id: &str, - model: &str, - created: u64, - logprobs: Option, - ) -> ChatCompletionStreamResponse { - ChatCompletionStreamResponse { - id: request_id.to_string(), - object: "chat.completion.chunk".to_string(), - created, - model: model.to_string(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: Some(content), - tool_calls: None, - reasoning_content: None, - }, - logprobs, - finish_reason: None, - matched_stop: None, - }], - usage: None, - } - } - - /// Helper: Format response as SSE chunk - fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String { - format!( - "data: {}\n\n", - serde_json::to_string(response).unwrap_or_default() - ) - } - - /// Submit request and handle streaming response for chat completions route - async fn handle_streaming_chat( - &self, - mut client: SglangSchedulerClient, - request: proto::GenerateRequest, - original_request: &ChatCompletionRequest, - ) -> Response { - let request_id = request.request_id.clone(); - let model = original_request.model.clone(); - - // Create channel 
for SSE streaming - let (tx, rx) = mpsc::unbounded_channel::>(); - - // Start the gRPC stream - let mut grpc_stream = match client.generate(request).await { - Ok(stream) => stream, - Err(e) => { - return utils::internal_error_message(format!("Generation failed: {}", e)); - } - }; - - let stop_params = ( - original_request.stop.clone(), - original_request.stop_token_ids.clone(), - original_request.skip_special_tokens, - original_request.no_stop_trim, - ); - - // Spawn processing task - let self_clone = self.clone(); - let original_request_clone = original_request.clone(); - tokio::spawn(async move { - let result = Self::process_streaming_chunks( - &self_clone, - &mut grpc_stream, - request_id, - model, - stop_params, - original_request_clone, - &tx, - ) - .await; - - if let Err(e) = result { - let error_chunk = format!( - "data: {}\n\n", - json!({ - "error": { - "message": e, - "type": "internal_error" - } - }) - ); - let _ = tx.send(Ok(Bytes::from(error_chunk))); - } - - // Send DONE marker - let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); - }); - - // Create response with SSE headers - let stream = UnboundedReceiverStream::new(rx); - let mut response = Response::new(Body::from_stream(stream)); - *response.status_mut() = StatusCode::OK; - response - .headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); - response - .headers_mut() - .insert("Cache-Control", HeaderValue::from_static("no-cache")); - response - .headers_mut() - .insert("Connection", HeaderValue::from_static("keep-alive")); - response - } - - /// Process streaming chunks and send SSE events - async fn process_streaming_chunks( - router: &GrpcRouter, - grpc_stream: &mut (impl tokio_stream::Stream> - + Unpin), - request_id: String, - model: String, - stop_params: (Option, Option>, bool, bool), - original_request: ChatCompletionRequest, - tx: &mpsc::UnboundedSender>, - ) -> Result<(), String> { - // Extract request parameters - let separate_reasoning = original_request.separate_reasoning; - let tool_choice = &original_request.tool_choice; - let tools = &original_request.tools; - let history_tool_calls_count = Self::get_history_tool_calls_count(&original_request); - let stream_options = &original_request.stream_options; - - // Phase 1: Initialize state tracking (per-index for n>1 support) - let mut is_firsts: HashMap = HashMap::new(); - let mut stream_buffers: HashMap = HashMap::new(); - let mut finish_reasons: HashMap = HashMap::new(); - let mut matched_stops: HashMap> = HashMap::new(); - let mut prompt_tokens: HashMap = HashMap::new(); - let mut completion_tokens: HashMap = HashMap::new(); - let mut cached_tokens: HashMap = HashMap::new(); - - // Parser state (lazy initialization per index) - type PooledReasoningParser = - Arc>>; - let mut reasoning_parsers: HashMap = HashMap::new(); - - type PooledToolParser = Arc>>; - let mut tool_parsers: HashMap = HashMap::new(); - let mut has_tool_calls: HashMap = HashMap::new(); - - // Create stop decoder - let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params; - let mut stop_decoder = utils::create_stop_decoder( - &router.tokenizer, - stop.as_ref(), - stop_token_ids.as_ref(), - skip_special_tokens, - no_stop_trim, - ); - - let created = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - - // Phase 2: Main streaming loop - while let Some(response) = grpc_stream.next().await { - let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; - - match gen_response.response { - 
Some(Chunk(chunk)) => { - let index = chunk.index; - - // Process tokens through stop decoder - let (chunk_text, _should_stop) = - Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids); - - if chunk_text.is_empty() { - continue; - } - - // Process logprobs if present - let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs { - match router.convert_proto_to_openai_logprobs(proto_logprobs) { - Ok(logprobs) => Some(logprobs), - Err(e) => { - warn!("Failed to process logprobs: {}", e); - None - } - } - } else { - None - }; - - // Initialize stream buffer if first time - let stream_buffer = stream_buffers.entry(index).or_default(); - - // Send first chunk with role - if is_firsts.get(&index).copied().unwrap_or(true) { - let first_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: None, - reasoning_content: None, - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }; - tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk)))) - .map_err(|_| "Failed to send first chunk".to_string())?; - is_firsts.insert(index, false); - } - - // Calculate delta - let mut delta = chunk_text; - stream_buffer.push_str(&delta); - - // Reasoning content handling - let in_reasoning = if separate_reasoning { - let (normal_text, reasoning_chunk, in_reasoning) = router - .process_reasoning_stream( - &delta, - index, - &mut reasoning_parsers, - &request_id, - &model, - created, - ); - if let Some(chunk) = reasoning_chunk { - tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk)))) - .map_err(|_| "Failed to send reasoning chunk".to_string())?; - } - delta = normal_text; - in_reasoning - } else { - false - }; - - // Tool call handling - let tool_choice_enabled = - !matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None))); - - if !in_reasoning && tool_choice_enabled && tools.is_some() { - let (should_skip, tool_chunks) = router - .process_tool_calls_stream( - &delta, - index, - &mut tool_parsers, - &mut has_tool_calls, - tools.as_ref().unwrap(), - &request_id, - &model, - created, - history_tool_calls_count, - ) - .await; - - for chunk in tool_chunks { - tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk)))) - .map_err(|_| "Failed to send tool call chunk".to_string())?; - } - - if should_skip { - continue; - } - } - - // Regular content emission - if !delta.is_empty() { - let content_chunk = Self::create_content_chunk( - delta, - index, - &request_id, - &model, - created, - choice_logprobs, - ); - tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk)))) - .map_err(|_| "Failed to send content chunk".to_string())?; - } - } - Some(Complete(complete)) => { - // Flush any remaining text - if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() { - if !text.is_empty() { - let index = complete.index; - let stream_buffer = stream_buffers.entry(index).or_default(); - stream_buffer.push_str(&text); - - let content_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: Some(text), - tool_calls: None, - reasoning_content: None, - }, 
- logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }; - - let sse_chunk = serde_json::to_string(&content_chunk) - .map_err(|e| format!("Failed to serialize content chunk: {}", e))?; - tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) - .map_err(|_| "Failed to send flushed content".to_string())?; - } - } - - // Store metadata - let index = complete.index; - prompt_tokens.insert(index, complete.prompt_tokens as u32); - completion_tokens.insert(index, complete.completion_tokens as u32); - cached_tokens.insert(index, complete.cached_tokens as u32); - finish_reasons.insert(index, complete.finish_reason.clone()); - - // Extract matched_stop - let matched_stop_value = match &complete.matched_stop { - Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { - Some(Value::Number(serde_json::Number::from(*token_id))) - } - Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { - Some(Value::String(stop_str.clone())) - } - None => None, - }; - matched_stops.insert(index, matched_stop_value); - - break; - } - Some(Error(error)) => { - return Err(error.message); - } - None => continue, - } - } - - // Phase 3: Check unstreamed tool args - // Check if parsers have any remaining arguments that haven't been streamed yet - for (index, parser) in &tool_parsers { - let parser_guard = parser.lock().await; - if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() { - for tool_call_item in unstreamed_items { - let tool_call_delta = ToolCallDelta { - index: tool_call_item.tool_index as u32, - id: None, - tool_type: None, // No type for argument deltas - function: Some(FunctionCallDelta { - name: None, // No name for argument deltas - arguments: if !tool_call_item.parameters.is_empty() { - Some(tool_call_item.parameters) - } else { - None - }, - }), - }; - - let tool_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index: *index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: Some(vec![tool_call_delta]), - reasoning_content: None, - }, - logprobs: None, - finish_reason: None, - matched_stop: None, - }], - usage: None, - }; - - let sse_chunk = serde_json::to_string(&tool_chunk) - .map_err(|e| format!("Failed to serialize tool chunk: {}", e))?; - tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) - .map_err(|_| "Failed to send unstreamed tool args".to_string())?; - } - } - } - - // Phase 4: Finish reason chunks - for (index, finish_reason) in finish_reasons.iter() { - let final_finish_reason = - if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" { - "tool_calls".to_string() - } else { - finish_reason.clone() - }; - - let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone()); - - let finish_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![ChatStreamChoice { - index: *index, - delta: ChatMessageDelta { - role: Some("assistant".to_string()), - content: None, - tool_calls: None, - reasoning_content: None, - }, - logprobs: None, - finish_reason: Some(final_finish_reason), - matched_stop: matched_stop_value, - }], - usage: None, - }; - - let sse_chunk = serde_json::to_string(&finish_chunk) - .map_err(|e| 
format!("Failed to serialize finish chunk: {}", e))?; - tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) - .map_err(|_| "Failed to send finish chunk".to_string())?; - } - - // Phase 5: Usage chunk - if let Some(stream_opts) = stream_options { - if stream_opts.include_usage.unwrap_or(false) { - let total_prompt: u32 = prompt_tokens.values().sum(); - let total_completion: u32 = completion_tokens.values().sum(); - - let usage_chunk = ChatCompletionStreamResponse { - id: request_id.clone(), - object: "chat.completion.chunk".to_string(), - created, - model: model.clone(), - system_fingerprint: None, - choices: vec![], - usage: Some(Usage { - prompt_tokens: total_prompt, - completion_tokens: total_completion, - total_tokens: total_prompt + total_completion, - completion_tokens_details: None, - }), - }; - - let sse_chunk = serde_json::to_string(&usage_chunk) - .map_err(|e| format!("Failed to serialize usage chunk: {}", e))?; - tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) - .map_err(|_| "Failed to send usage chunk".to_string())?; - } - } - - Ok(()) - } - - /// Submit request and handle non-streaming response for chat completions route - async fn handle_non_streaming_chat( - &self, - mut client: SglangSchedulerClient, - request: proto::GenerateRequest, - original_request: &ChatCompletionRequest, - ) -> Response { - let mut stop_decoder = utils::create_stop_decoder( - &self.tokenizer, - original_request.stop.as_ref(), - original_request.stop_token_ids.as_ref(), - original_request.skip_special_tokens, - original_request.no_stop_trim, - ); - - // Start generation - let stream = match client.generate(request).await { - Ok(s) => s, - Err(e) => { - return utils::internal_error_message(format!("Failed to start generation: {}", e)) - } - }; - - let all_responses = match utils::collect_stream_responses(stream, "Regular").await { - Ok(responses) => responses, - Err(err_response) => return err_response, - }; - - if all_responses.is_empty() { - return utils::internal_error_static("No responses from server"); - } - - // Process each response into a ChatChoice - let history_tool_calls_count = Self::get_history_tool_calls_count(original_request); - let mut choices = Vec::new(); - for (index, complete) in all_responses.iter().enumerate() { - match self - .process_single_choice( - complete, - index, - original_request, - &mut stop_decoder, - history_tool_calls_count, - ) - .await - { - Ok(choice) => choices.push(choice), - Err(e) => { - return utils::internal_error_message(format!( - "Failed to process choice {}: {}", - index, e - )); - } - } - } - - // Aggregate usage information from all responses - let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum(); - let total_completion_tokens: u32 = all_responses - .iter() - .map(|r| r.completion_tokens as u32) - .sum(); - let usage = Usage { - prompt_tokens: total_prompt_tokens, - completion_tokens: total_completion_tokens, - total_tokens: total_prompt_tokens + total_completion_tokens, - completion_tokens_details: None, - }; - - // Build final ChatCompletionResponse - let response = ChatCompletionResponse { - id: format!("chatcmpl-{}", Uuid::new_v4()), - object: "chat.completion".to_string(), - created: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(), - model: original_request.model.clone(), - choices, - usage: Some(usage), - system_fingerprint: None, - }; - - // Serialize and return JSON response - Json(response).into_response() - } - /// Submit request and handle 
non-streaming response for the `/generate` endpoint async fn handle_non_streaming_generate( &self, @@ -1498,234 +564,6 @@ impl GrpcRouter { Ok(()) } - - /// Convert proto LogProbs to OpenAI ChatLogProbs format - /// Note: Always decodes with skip_special_tokens=false to show actual tokens generated - fn convert_proto_to_openai_logprobs( - &self, - proto_logprobs: &proto::OutputLogProbs, - ) -> Result { - let mut content_items = Vec::new(); - - // Decode token IDs to text (always with skip_special_tokens=false for logprobs) - let token_texts: Vec = proto_logprobs - .token_ids - .iter() - .map(|&token_id| { - self.tokenizer - .decode(&[token_id as u32], false) - .unwrap_or_else(|_| format!("", token_id)) - }) - .collect(); - - // Build ChatLogProbsContent for each token (consume iterator to avoid clones) - for (i, (&logprob, token_text)) in proto_logprobs - .token_logprobs - .iter() - .zip(token_texts.into_iter()) - .enumerate() - { - let bytes = Some(token_text.as_bytes().to_vec()); - - // Build top_logprobs for this position - let mut top_logprobs = Vec::new(); - if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) { - // Decode top token IDs (always with skip_special_tokens=false) - let top_token_texts: Vec = top_logprobs_entry - .token_ids - .iter() - .map(|&tid| { - self.tokenizer - .decode(&[tid as u32], false) - .unwrap_or_else(|_| format!("", tid)) - }) - .collect(); - - for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry - .values - .iter() - .zip(top_logprobs_entry.token_ids.iter()) - .enumerate() - { - if let Some(top_token_text) = top_token_texts.get(j) { - top_logprobs.push(crate::protocols::spec::TopLogProb { - token: top_token_text.clone(), - logprob: top_logprob, - bytes: Some(top_token_text.as_bytes().to_vec()), - }); - } - } - } - - content_items.push(crate::protocols::spec::ChatLogProbsContent { - token: token_text, - logprob, - bytes, - top_logprobs, - }); - } - - Ok(crate::protocols::spec::ChatLogProbs::Detailed { - content: (!content_items.is_empty()).then_some(content_items), - }) - } - - /// Process a single GenerateComplete response into a ChatChoice - async fn process_single_choice( - &self, - complete: &proto::GenerateComplete, - index: usize, - original_request: &ChatCompletionRequest, - stop_decoder: &mut StopSequenceDecoder, - history_tool_calls_count: usize, - ) -> Result { - stop_decoder.reset(); - // Decode tokens - let outputs = stop_decoder - .process_tokens(&complete.output_ids) - .map_err(|e| format!("Failed to process tokens: {}", e))?; - - // Accumulate text with early breaks - let mut final_text = String::new(); - for output in outputs { - match output { - SequenceDecoderOutput::Text(t) => final_text.push_str(&t), - SequenceDecoderOutput::StoppedWithText(t) => { - final_text.push_str(&t); - break; - } - SequenceDecoderOutput::Stopped => break, - SequenceDecoderOutput::Held => {} - } - } - - // Flush remaining text - if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { - final_text.push_str(&t); - } - - // Step 1: Handle reasoning content parsing - let mut reasoning_text: Option = None; - let mut processed_text = final_text; - - // Check if reasoning parsing is enabled and separate_reasoning is requested - if original_request.separate_reasoning { - let pooled_parser = utils::get_reasoning_parser( - &self.reasoning_parser_factory, - self.configured_reasoning_parser.as_ref(), - &original_request.model, - ); - - let mut parser = pooled_parser - .lock() - .map_err(|e| format!("Failed to acquire reasoning parser lock: {}", 
e))?; - match parser.detect_and_parse_reasoning(&processed_text) { - Ok(result) => { - if !result.reasoning_text.is_empty() { - reasoning_text = Some(result.reasoning_text); - } - processed_text = result.normal_text; - } - Err(e) => { - return Err(format!("Reasoning parsing error: {}", e)); - } - } - } - - // Step 2: Handle tool call parsing - let mut tool_calls: Option> = None; - - // Check if tool calls should be processed - let tool_choice_enabled = !matches!( - &original_request.tool_choice, - Some(ToolChoice::Value(ToolChoiceValue::None)) - ); - - if tool_choice_enabled && original_request.tools.is_some() { - // Check if JSON schema constraint was used (specific function or required mode) - let used_json_schema = match &original_request.tool_choice { - Some(ToolChoice::Function { .. }) => true, - Some(ToolChoice::Value(ToolChoiceValue::Required)) => true, - Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required", - _ => false, - }; - - if used_json_schema { - (tool_calls, processed_text) = utils::parse_json_schema_response( - &processed_text, - &original_request.tool_choice, - ); - } else { - (tool_calls, processed_text) = self - .parse_tool_calls( - &processed_text, - &original_request.model, - history_tool_calls_count, - ) - .await; - } - } - - // Step 3: Use finish reason directly from proto (already OpenAI-compatible string) - let finish_reason_str = &complete.finish_reason; - - // Override finish reason if we have tool calls - let final_finish_reason_str = if tool_calls.is_some() { - "tool_calls" - } else { - finish_reason_str - }; - - // Extract matched_stop information from proto - let matched_stop = match &complete.matched_stop { - Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { - Some(Value::Number(serde_json::Number::from(*token_id))) - } - Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { - Some(Value::String(stop_str.clone())) - } - None => None, - }; - - // Step 4: Convert output logprobs if present - // Note: complete.input_logprobs exists in proto but is not used for chat completions - // (input logprobs are only used in /v1/completions endpoint with echo=true) - let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs { - match self.convert_proto_to_openai_logprobs(proto_logprobs) { - Ok(logprobs) => Some(logprobs), - Err(e) => { - error!("Failed to convert logprobs: {}", e); - None - } - } - } else { - None - }; - - // Step 5: Build ChatCompletionMessage (proper response message type) - let chat_message = ChatCompletionMessage { - role: "assistant".to_string(), - content: if processed_text.is_empty() { - None - } else { - Some(processed_text) - }, - tool_calls, - reasoning_content: reasoning_text, - }; - - // Step 6: Build ChatChoice - let choice = ChatChoice { - index: index as u32, - message: chat_message, - logprobs, - finish_reason: Some(final_finish_reason_str.to_string()), - matched_stop, - hidden_states: None, - }; - - Ok(choice) - } } impl std::fmt::Debug for GrpcRouter { diff --git a/sgl-router/src/routers/grpc/streaming.rs b/sgl-router/src/routers/grpc/streaming.rs new file mode 100644 index 000000000..0337ce365 --- /dev/null +++ b/sgl-router/src/routers/grpc/streaming.rs @@ -0,0 +1,861 @@ +//! Streaming response processor for gRPC routers +//! +//! This module contains shared streaming logic for both Regular and PD routers, +//! eliminating ~600 lines of duplication. 
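// Illustrative usage sketch (not part of the patch): how a router is expected to hand a
// single-worker stream to the shared StreamingProcessor defined below. The helper name
// `stream_chat` and the `Streaming<proto::GenerateResponse>` / `Arc<StreamingProcessor>`
// type parameters are assumptions inferred from the surrounding code, not confirmed API.
//
// fn stream_chat(
//     processor: std::sync::Arc<StreamingProcessor>,
//     stream: tonic::codec::Streaming<proto::GenerateResponse>,
//     chat_request: ChatCompletionRequest,
//     dispatch: context::DispatchMetadata,
// ) -> axum::response::Response {
//     // Wrap the tonic stream in the single-worker execution variant; the processor
//     // spawns the SSE task and returns the streaming HTTP response.
//     processor.process_streaming_response(
//         context::ExecutionResult::Single { stream },
//         chat_request,
//         dispatch,
//     )
// }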
+ +use axum::response::Response; +use axum::{body::Body, http::StatusCode}; +use bytes::Bytes; +use http::header::{HeaderValue, CONTENT_TYPE}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::io; +use std::sync::Arc; +use tokio::sync::mpsc::UnboundedSender; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::StreamExt; +use tonic::codec::Streaming; +use tracing::{debug, error, warn}; + +use crate::grpc_client::proto; +use crate::protocols::spec::*; +use crate::reasoning_parser::ReasoningParser; +use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; +use crate::tokenizer::traits::Tokenizer; +use crate::tool_parser::ToolParser; + +use super::context; +use super::utils; + +/// Shared streaming processor for both single and dual dispatch modes +#[derive(Clone)] +pub struct StreamingProcessor { + tokenizer: Arc, + tool_parser_factory: crate::tool_parser::ToolParserFactory, + reasoning_parser_factory: crate::reasoning_parser::ReasoningParserFactory, + configured_tool_parser: Option, + configured_reasoning_parser: Option, +} + +impl StreamingProcessor { + pub fn new( + tokenizer: Arc, + tool_parser_factory: crate::tool_parser::ToolParserFactory, + reasoning_parser_factory: crate::reasoning_parser::ReasoningParserFactory, + configured_tool_parser: Option, + configured_reasoning_parser: Option, + ) -> Self { + Self { + tokenizer, + tool_parser_factory, + reasoning_parser_factory, + configured_tool_parser, + configured_reasoning_parser, + } + } + + /// Process streaming chat response and return SSE response + /// + /// This is the high-level entry point for streaming responses, handling: + /// - Channel creation + /// - Background task spawning + /// - SSE response building + pub fn process_streaming_response( + self: Arc, + execution_result: context::ExecutionResult, + chat_request: ChatCompletionRequest, + dispatch: context::DispatchMetadata, + ) -> axum::response::Response { + use bytes::Bytes; + use tokio::sync::mpsc; + + let stop_params = ( + chat_request.stop.clone(), + chat_request.stop_token_ids.clone(), + chat_request.skip_special_tokens, + chat_request.no_stop_trim, + ); + + // Create SSE channel + let (tx, rx) = mpsc::unbounded_channel::>(); + + // Spawn background task based on execution mode + match execution_result { + context::ExecutionResult::Single { stream } => { + let processor = self.clone(); + let dispatch_clone = dispatch.clone(); + tokio::spawn(async move { + let result = processor + .process_streaming_chunks( + stream, + dispatch_clone, + stop_params, + chat_request, + &tx, + ) + .await; + + if let Err(e) = result { + let error_chunk = format!( + "data: {}\n\n", + json!({ + "error": { + "message": e, + "type": "internal_error" + } + }) + ); + let _ = tx.send(Ok(Bytes::from(error_chunk))); + } + + let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); + }); + } + context::ExecutionResult::Dual { prefill, decode } => { + let processor = self.clone(); + tokio::spawn(async move { + let result = processor + .process_dual_streaming_chunks( + prefill, + *decode, + dispatch, + stop_params, + chat_request, + &tx, + ) + .await; + + if let Err(e) = result { + let error_chunk = format!( + "data: {}\n\n", + json!({ + "error": { + "message": e, + "type": "internal_error" + } + }) + ); + let _ = tx.send(Ok(Bytes::from(error_chunk))); + } + + let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n"))); + }); + } + } + + // Return SSE response + build_sse_response(rx) + } + + /// Process streaming chunks from a single stream (Regular 
mode) + pub async fn process_streaming_chunks( + &self, + mut grpc_stream: Streaming, + dispatch: context::DispatchMetadata, + stop_params: (Option, Option>, bool, bool), + original_request: ChatCompletionRequest, + tx: &UnboundedSender>, + ) -> Result<(), String> { + // Extract request parameters + let separate_reasoning = original_request.separate_reasoning; + let tool_choice = &original_request.tool_choice; + let tools = &original_request.tools; + let history_tool_calls_count = utils::get_history_tool_calls_count(&original_request); + let stream_options = &original_request.stream_options; + + // Phase 1: Initialize state tracking (per-index for n>1 support) + let mut is_firsts: HashMap = HashMap::new(); + let mut stream_buffers: HashMap = HashMap::new(); + let mut finish_reasons: HashMap = HashMap::new(); + let mut matched_stops: HashMap> = HashMap::new(); + let mut prompt_tokens: HashMap = HashMap::new(); + let mut completion_tokens: HashMap = HashMap::new(); + let mut cached_tokens: HashMap = HashMap::new(); + + // Parser state (lazy initialization per index) + type PooledReasoningParser = Arc>>; + let mut reasoning_parsers: HashMap = HashMap::new(); + + type PooledToolParser = Arc>>; + let mut tool_parsers: HashMap = HashMap::new(); + let mut has_tool_calls: HashMap = HashMap::new(); + + // Per-index stop decoders (each index needs its own state for n>1 support) + let mut stop_decoders: HashMap = HashMap::new(); + + // Use dispatch metadata for consistent response fields + let request_id = &dispatch.request_id; + let model = &dispatch.model; + let created = dispatch.created; + let system_fingerprint = dispatch.weight_version.as_deref(); + + // Phase 2: Main streaming loop + while let Some(response) = grpc_stream.next().await { + let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; + + match gen_response.response { + Some(proto::generate_response::Response::Chunk(chunk)) => { + let index = chunk.index; + + // Get or create stop decoder for this index + let stop_decoder = stop_decoders.entry(index).or_insert_with(|| { + let (ref stop, ref stop_token_ids, skip_special_tokens, no_stop_trim) = + stop_params; + utils::create_stop_decoder( + &self.tokenizer, + stop.as_ref(), + stop_token_ids.as_ref(), + skip_special_tokens, + no_stop_trim, + ) + }); + + // Process tokens through stop decoder + let (chunk_text, _should_stop) = + Self::process_chunk_tokens(stop_decoder, &chunk.token_ids); + + if chunk_text.is_empty() { + continue; + } + + // Process logprobs if present + let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs { + match utils::convert_proto_to_openai_logprobs( + proto_logprobs, + &self.tokenizer, + ) { + Ok(logprobs) => Some(logprobs), + Err(e) => { + warn!("Failed to process logprobs: {}", e); + None + } + } + } else { + None + }; + + // Initialize stream buffer if first time + let stream_buffer = stream_buffers.entry(index).or_default(); + + // Send first chunk with role + if is_firsts.get(&index).copied().unwrap_or(true) { + let first_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: system_fingerprint.map(|s| s.to_string()), + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }; + 
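// For reference, the role-only first chunk built above serializes to an SSE event of
// roughly this shape (values are illustrative; which null fields are omitted depends on
// the serde attributes on the spec structs):
//
//   data: {"id":"chatcmpl-...","object":"chat.completion.chunk","created":1700000000,
//     "model":"default","choices":[{"index":0,"delta":{"role":"assistant"},
//     "logprobs":null,"finish_reason":null}]}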
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk)))) + .map_err(|_| "Failed to send first chunk".to_string())?; + is_firsts.insert(index, false); + } + + // Calculate delta + let mut delta = chunk_text; + stream_buffer.push_str(&delta); + + // Reasoning content handling + let in_reasoning = if separate_reasoning { + let (normal_text, reasoning_chunk, in_reasoning) = self + .process_reasoning_stream( + &delta, + index, + &mut reasoning_parsers, + request_id, + model, + created, + system_fingerprint, + ); + if let Some(chunk) = reasoning_chunk { + tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk)))) + .map_err(|_| "Failed to send reasoning chunk".to_string())?; + } + delta = normal_text; + in_reasoning + } else { + false + }; + + // Tool call handling + let tool_choice_enabled = + !matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None))); + + if !in_reasoning && tool_choice_enabled && tools.is_some() { + let (should_skip, tool_chunks) = self + .process_tool_calls_stream( + &delta, + index, + &mut tool_parsers, + &mut has_tool_calls, + tools.as_ref().unwrap(), + request_id, + model, + created, + system_fingerprint, + history_tool_calls_count, + ) + .await; + + for chunk in tool_chunks { + tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk)))) + .map_err(|_| "Failed to send tool call chunk".to_string())?; + } + + // Continue to process the next chunk as we have tool chunks + if should_skip { + continue; + } + } + + // Regular content emission + if !delta.is_empty() { + let content_chunk = Self::create_content_chunk( + delta, + index, + request_id, + model, + created, + system_fingerprint, + choice_logprobs, + ); + tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk)))) + .map_err(|_| "Failed to send content chunk".to_string())?; + } + } + Some(proto::generate_response::Response::Complete(complete)) => { + let index = complete.index; + + // Flush any remaining text for this index's stop_decoder + if let Some(decoder) = stop_decoders.get_mut(&index) { + if let SequenceDecoderOutput::Text(text) = decoder.flush() { + if !text.is_empty() { + let stream_buffer = stream_buffers.entry(index).or_default(); + stream_buffer.push_str(&text); + + let content_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: system_fingerprint.map(|s| s.to_string()), + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: Some(text), + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }; + + let sse_chunk = + serde_json::to_string(&content_chunk).map_err(|e| { + format!("Failed to serialize content chunk: {}", e) + })?; + tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send flushed content".to_string())?; + } + } + } + + // Store metadata + prompt_tokens.insert(index, complete.prompt_tokens as u32); + completion_tokens.insert(index, complete.completion_tokens as u32); + cached_tokens.insert(index, complete.cached_tokens as u32); + finish_reasons.insert(index, complete.finish_reason.clone()); + + // Extract matched_stop + let matched_stop_value = match &complete.matched_stop { + Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { + Some(Value::Number(serde_json::Number::from(*token_id))) + } + 
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { + Some(Value::String(stop_str.clone())) + } + None => None, + }; + matched_stops.insert(index, matched_stop_value); + + // Don't break - continue reading all Complete messages for n>1 + } + Some(proto::generate_response::Response::Error(error)) => { + return Err(error.message); + } + None => continue, + } + } + + // Phase 3: Check unstreamed tool args + for (index, parser) in &tool_parsers { + let parser_guard = parser.lock().await; + if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() { + for tool_call_item in unstreamed_items { + let tool_call_delta = ToolCallDelta { + index: tool_call_item.tool_index as u32, + id: None, + tool_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: if !tool_call_item.parameters.is_empty() { + Some(tool_call_item.parameters) + } else { + None + }, + }), + }; + + let tool_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: system_fingerprint.map(|s| s.to_string()), + choices: vec![ChatStreamChoice { + index: *index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: Some(vec![tool_call_delta]), + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }; + + let sse_chunk = serde_json::to_string(&tool_chunk) + .map_err(|e| format!("Failed to serialize tool chunk: {}", e))?; + tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send unstreamed tool args".to_string())?; + } + } + } + + // Phase 4: Finish reason chunks + for (index, finish_reason) in finish_reasons.iter() { + let final_finish_reason = + if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" { + "tool_calls".to_string() + } else { + finish_reason.clone() + }; + + let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone()); + + let finish_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: system_fingerprint.map(|s| s.to_string()), + choices: vec![ChatStreamChoice { + index: *index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: Some(final_finish_reason), + matched_stop: matched_stop_value, + }], + usage: None, + }; + + let sse_chunk = serde_json::to_string(&finish_chunk) + .map_err(|e| format!("Failed to serialize finish chunk: {}", e))?; + tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send finish chunk".to_string())?; + } + + // Phase 5: Usage chunk + if let Some(stream_opts) = stream_options { + if stream_opts.include_usage.unwrap_or(false) { + let total_prompt: u32 = prompt_tokens.values().sum(); + let total_completion: u32 = completion_tokens.values().sum(); + + let usage_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: system_fingerprint.map(|s| s.to_string()), + choices: vec![], + usage: Some(Usage { + prompt_tokens: total_prompt, + completion_tokens: total_completion, + total_tokens: total_prompt + total_completion, + completion_tokens_details: None, + }), + }; + + let sse_chunk = 
serde_json::to_string(&usage_chunk) + .map_err(|e| format!("Failed to serialize usage chunk: {}", e))?; + tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send usage chunk".to_string())?; + } + } + + Ok(()) + } + + /// Process dual streaming chunks (prefill + decode) - PD mode + pub async fn process_dual_streaming_chunks( + &self, + mut prefill_stream: Streaming, + decode_stream: Streaming, + dispatch: context::DispatchMetadata, + stop_params: (Option, Option>, bool, bool), + original_request: ChatCompletionRequest, + tx: &UnboundedSender>, + ) -> Result<(), String> { + // Phase 1.5: Collect input_logprobs from prefill stream if requested + if original_request.logprobs { + while let Some(response) = prefill_stream.next().await { + let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?; + match gen_response.response { + Some(proto::generate_response::Response::Complete(_complete)) => { + // Input logprobs collected but not yet used in streaming + // (OpenAI spec doesn't require prompt logprobs in streaming responses) + break; + } + Some(proto::generate_response::Response::Error(error)) => { + return Err(format!("Prefill error: {}", error.message)); + } + _ => continue, + } + } + } + + // Phase 2-5: Process decode stream (same as single mode) + self.process_streaming_chunks(decode_stream, dispatch, stop_params, original_request, tx) + .await + } + + // TODO(generate): Add streaming generate handler + // + // pub async fn process_streaming_generate( + // self: Arc, + // execution_result: context::ExecutionResult, + // generate_request: GenerateRequest, + // dispatch: context::DispatchMetadata, + // ) -> axum::response::Response { + // // Similar to process_streaming_response but: + // // - No tool parsing + // // - No reasoning parsing + // // - Simpler chunk format (just text + finish_reason + logprobs) + // // - Extract stop params from generate_request.sampling_params + // // - Use same per-index stop decoder logic + // // - Emit SSE chunks with format similar to chat but without delta.tool_calls + // // Reference: router.rs:422-595 + // } + + // ======================================================================== + // Helper Methods + // ======================================================================== + + /// Process a chunk of tokens through the stop decoder + fn process_chunk_tokens( + stop_decoder: &mut StopSequenceDecoder, + token_ids: &[u32], + ) -> (String, bool) { + let mut chunk_text = String::new(); + + for &token_id in token_ids { + match stop_decoder.process_token(token_id).unwrap_or_else(|e| { + debug!( + "Error processing token {}: {}. 
Treating as Held.", + token_id, e + ); + SequenceDecoderOutput::Held + }) { + SequenceDecoderOutput::Text(text) => { + chunk_text.push_str(&text); + } + SequenceDecoderOutput::StoppedWithText(text) => { + chunk_text.push_str(&text); + return (chunk_text, true); + } + SequenceDecoderOutput::Stopped => { + return (chunk_text, true); + } + SequenceDecoderOutput::Held => {} + } + } + (chunk_text, false) + } + + /// Helper: Process reasoning content in streaming mode + #[allow(clippy::too_many_arguments)] + fn process_reasoning_stream( + &self, + delta: &str, + index: u32, + reasoning_parsers: &mut HashMap>>>, + request_id: &str, + model: &str, + created: u64, + system_fingerprint: Option<&str>, + ) -> (String, Option, bool) { + // Get or create parser for this index + reasoning_parsers.entry(index).or_insert_with(|| { + utils::get_reasoning_parser( + &self.reasoning_parser_factory, + self.configured_reasoning_parser.as_ref(), + model, + ) + }); + + if let Some(pooled_parser) = reasoning_parsers.get(&index) { + let (parse_result, in_reasoning) = { + let mut parser = pooled_parser.lock().unwrap(); + let result = parser.parse_reasoning_streaming_incremental(delta); + let in_reasoning = parser.is_in_reasoning(); + (result, in_reasoning) + }; + + match parse_result { + Ok(crate::reasoning_parser::ParserResult { + reasoning_text, + normal_text, + }) => { + let chunk = if !reasoning_text.is_empty() { + Some(ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: system_fingerprint.map(|s| s.to_string()), + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: None, + reasoning_content: Some(reasoning_text), + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }) + } else { + None + }; + return (normal_text, chunk, in_reasoning); + } + Err(e) => { + warn!("Reasoning parsing error: {}", e); + } + } + } + + (delta.to_string(), None, false) + } + + /// Helper: Process tool calls in streaming mode + #[allow(clippy::too_many_arguments)] + async fn process_tool_calls_stream( + &self, + delta: &str, + index: u32, + tool_parsers: &mut HashMap>>>, + has_tool_calls: &mut HashMap, + tools: &[Tool], + request_id: &str, + model: &str, + created: u64, + system_fingerprint: Option<&str>, + history_tool_calls_count: usize, + ) -> (bool, Vec) { + let mut chunks = Vec::new(); + + // Get or create parser for this index + tool_parsers.entry(index).or_insert_with(|| { + utils::get_tool_parser( + &self.tool_parser_factory, + self.configured_tool_parser.as_ref(), + model, + ) + }); + + if let Some(pooled_parser) = tool_parsers.get(&index) { + let mut parser = pooled_parser.lock().await; + match parser.parse_incremental(delta, tools).await { + Ok(crate::tool_parser::StreamingParseResult { normal_text, calls }) => { + // Emit normal text if present + if !normal_text.is_empty() { + chunks.push(ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: system_fingerprint.map(|s| s.to_string()), + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: Some(normal_text), + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }); + } + + // Emit 
tool call chunks + for tool_call_item in calls { + has_tool_calls.insert(index, true); + + let tool_call_id = if let Some(ref name) = tool_call_item.name { + Some(utils::generate_tool_call_id( + model, + name, + tool_call_item.tool_index, + history_tool_calls_count, + )) + } else { + None + }; + + let tool_call_delta = ToolCallDelta { + index: tool_call_item.tool_index as u32, + id: tool_call_id, + tool_type: if tool_call_item.name.is_some() { + Some("function".to_string()) + } else { + None + }, + function: Some(FunctionCallDelta { + name: tool_call_item.name, + arguments: if !tool_call_item.parameters.is_empty() { + Some(tool_call_item.parameters) + } else { + None + }, + }), + }; + + chunks.push(ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: system_fingerprint.map(|s| s.to_string()), + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: Some(vec![tool_call_delta]), + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }); + } + + // If we emitted chunks, skip regular content + return (!chunks.is_empty(), chunks); + } + Err(e) => { + error!("Tool call parsing error: {}", e); + } + } + } + + (false, chunks) + } + + /// Format a response as SSE chunk + fn format_sse_chunk(chunk: &ChatCompletionStreamResponse) -> String { + match serde_json::to_string(chunk) { + Ok(json) => format!("data: {}\n\n", json), + Err(e) => { + error!("Failed to serialize SSE chunk: {}", e); + format!("data: {}\n\n", json!({"error": "serialization_failed"})) + } + } + } + + /// Create a content chunk response + fn create_content_chunk( + content: String, + index: u32, + request_id: &str, + model: &str, + created: u64, + system_fingerprint: Option<&str>, + logprobs: Option, + ) -> ChatCompletionStreamResponse { + ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: system_fingerprint.map(|s| s.to_string()), + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: Some(content), + tool_calls: None, + reasoning_content: None, + }, + logprobs, + finish_reason: None, + matched_stop: None, + }], + usage: None, + } + } +} + +/// Build SSE response with proper headers +pub fn build_sse_response( + rx: tokio::sync::mpsc::UnboundedReceiver>, +) -> Response { + let stream = UnboundedReceiverStream::new(rx); + let mut response = Response::new(Body::from_stream(stream)); + *response.status_mut() = StatusCode::OK; + response + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + response + .headers_mut() + .insert("Cache-Control", HeaderValue::from_static("no-cache")); + response + .headers_mut() + .insert("Connection", HeaderValue::from_static("keep-alive")); + response +} diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs index af8d30544..c03f2bc8d 100644 --- a/sgl-router/src/routers/grpc/utils.rs +++ b/sgl-router/src/routers/grpc/utils.rs @@ -4,8 +4,8 @@ use super::ProcessedMessages; use crate::core::Worker; use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::protocols::spec::{ - ChatCompletionRequest, ChatMessage, FunctionCallResponse, StringOrArray, Tool, ToolCall, - ToolChoice, 
ToolChoiceValue,
+    ChatCompletionRequest, ChatLogProbs, ChatLogProbsContent, ChatMessage, FunctionCallResponse,
+    StringOrArray, Tool, ToolCall, ToolChoice, ToolChoiceValue, TopLogProb,
 };
 use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
 use crate::tokenizer::traits::Tokenizer;
@@ -736,6 +736,79 @@ pub fn get_tool_parser(
     }
 }
 
+/// Convert proto::OutputLogProbs to OpenAI ChatLogProbs format
+///
+/// This function decodes token IDs using the tokenizer and builds the logprobs structure
+/// expected by the OpenAI API format.
+pub fn convert_proto_to_openai_logprobs(
+    proto_logprobs: &proto::OutputLogProbs,
+    tokenizer: &Arc<dyn Tokenizer>,
+) -> Result<ChatLogProbs, String> {
+    let mut content_items = Vec::new();
+
+    // Decode token IDs to text (always with skip_special_tokens=false for logprobs)
+    let token_texts: Vec<String> = proto_logprobs
+        .token_ids
+        .iter()
+        .map(|&token_id| {
+            tokenizer
+                .decode(&[token_id as u32], false)
+                .unwrap_or_else(|_| format!("<token_{}>", token_id))
+        })
+        .collect();
+
+    // Build ChatLogProbsContent for each token (consume iterator to avoid clones)
+    for (i, (&logprob, token_text)) in proto_logprobs
+        .token_logprobs
+        .iter()
+        .zip(token_texts.into_iter())
+        .enumerate()
+    {
+        let bytes = Some(token_text.as_bytes().to_vec());
+
+        // Build top_logprobs for this position
+        let mut top_logprobs = Vec::new();
+        if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) {
+            // Decode top token IDs (always with skip_special_tokens=false)
+            let top_token_texts: Vec<String> = top_logprobs_entry
+                .token_ids
+                .iter()
+                .map(|&tid| {
+                    tokenizer
+                        .decode(&[tid as u32], false)
+                        .unwrap_or_else(|_| format!("<token_{}>", tid))
+                })
+                .collect();
+
+            for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry
+                .values
+                .iter()
+                .zip(top_logprobs_entry.token_ids.iter())
+                .enumerate()
+            {
+                if let Some(top_token_text) = top_token_texts.get(j) {
+                    top_logprobs.push(TopLogProb {
+                        token: top_token_text.clone(),
+                        logprob: top_logprob,
+                        bytes: Some(top_token_text.as_bytes().to_vec()),
+                    });
+                }
+            }
+        }
+
+        content_items.push(ChatLogProbsContent {
+            token: token_text,
+            logprob,
+            bytes,
+            top_logprobs,
+        });
+    }
+
+    Ok(ChatLogProbs::Detailed {
+        content: (!content_items.is_empty()).then_some(content_items),
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
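// A hand-worked example of the shape the function above produces, which may be useful
// when adding tests to this module. Given one generated token whose id decodes to
// "Hello" with logprob -0.12 and no top_logprobs entry, the ChatLogProbs::Detailed value
// serializes to roughly (exact field omission depends on the serde attributes in
// protocols::spec):
//
//   {"content":[{"token":"Hello","logprob":-0.12,
//     "bytes":[72,101,108,108,111],"top_logprobs":[]}]}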