From d736e0b65e0f7d0272de3fa4a5c911c1bc1ad3a9 Mon Sep 17 00:00:00 2001
From: Simo Lin
Date: Sat, 4 Oct 2025 09:58:28 -0400
Subject: [PATCH] [router] add grpc router pd mode for chat and generate
 (#11140)

---
 python/pyproject.toml                         |    4 +-
 .../srt/entrypoints/grpc_request_manager.py   |    8 +-
 python/sglang/srt/entrypoints/grpc_server.py  |  104 +-
 sgl-router/src/core/worker.rs                 |   31 +-
 sgl-router/src/core/worker_builder.rs         |   19 +
 .../src/grpc_client/sglang_scheduler.rs       |    6 +
 sgl-router/src/protocols/spec.rs              |    6 +
 sgl-router/src/routers/grpc/mod.rs            |   12 +
 sgl-router/src/routers/grpc/pd_router.rs      | 1987 ++++++++++++++++-
 sgl-router/src/routers/grpc/router.rs         | 1229 +++-------
 sgl-router/src/routers/grpc/utils.rs          |  843 +++++++
 11 files changed, 3169 insertions(+), 1080 deletions(-)
 create mode 100644 sgl-router/src/routers/grpc/utils.rs

diff --git a/python/pyproject.toml b/python/pyproject.toml
index d112583bf..9ee96f739 100755
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -67,8 +67,8 @@ dependencies = [
     "uvicorn",
     "uvloop",
     "xgrammar==0.1.24",
-    "grpcio==1.74.0", # keep it align with compile_proto.py
-    "grpcio-tools==1.74.0" # keep it align with compile_proto.py
+    "grpcio==1.75.1", # keep it align with compile_proto.py
+    "grpcio-tools==1.75.1" # keep it align with compile_proto.py
 ]

 [project.optional-dependencies]
diff --git a/python/sglang/srt/entrypoints/grpc_request_manager.py b/python/sglang/srt/entrypoints/grpc_request_manager.py
index e2eb541a4..79f730aea 100644
--- a/python/sglang/srt/entrypoints/grpc_request_manager.py
+++ b/python/sglang/srt/entrypoints/grpc_request_manager.py
@@ -19,7 +19,6 @@
 import grpc
 import zmq
 import zmq.asyncio
-from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOutput,
@@ -111,6 +110,7 @@ class GrpcRequestManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
+        bootstrap_server=None,
     ):
         """Initialize the gRPC request manager."""
         self.server_args = server_args
@@ -147,8 +147,8 @@ class GrpcRequestManager:
         self.crash_dump_request_list = []
         self.crash_dump_performed = False

-        # Bootstrap server for disaggregation mode
-        self.bootstrap_server = start_disagg_service(server_args)
+        # Bootstrap server (passed from serve_grpc, not started here)
+        self.bootstrap_server = bootstrap_server

         logger.info(
             f"GrpcRequestManager initialized with ZMQ IPC: "
@@ -157,7 +157,7 @@ class GrpcRequestManager:
         )
         if self.bootstrap_server:
             logger.info(
-                f"Bootstrap server started for disaggregation mode: "
+                f"Bootstrap server initialized for disaggregation mode: "
                 f"{server_args.disaggregation_mode}"
             )
diff --git a/python/sglang/srt/entrypoints/grpc_server.py b/python/sglang/srt/entrypoints/grpc_server.py
index c143158bb..90ac9ca5f 100644
--- a/python/sglang/srt/entrypoints/grpc_server.py
+++ b/python/sglang/srt/entrypoints/grpc_server.py
@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple

 import grpc
 from grpc_reflection.v1alpha import reflection
+from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
 from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
 from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -331,6 +333,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         token_ids_logprob=None,
         )

+        if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
+            health_request.bootstrap_host = FAKE_BOOTSTRAP_HOST
+            health_request.bootstrap_room = 0
+
         logger.info(f"Sending health check request to request manager...")

         # Submit and wait for response
@@ -406,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         # Convert sampling params
         sampling_params = self._convert_sampling_params(grpc_req.sampling_params)

+        # Extract disaggregated params if present
+        bootstrap_host = None
+        bootstrap_port = None
+        bootstrap_room = None
+        if grpc_req.HasField("disaggregated_params"):
+            bootstrap_host = grpc_req.disaggregated_params.bootstrap_host or None
+            bootstrap_port = grpc_req.disaggregated_params.bootstrap_port or None
+            bootstrap_room = grpc_req.disaggregated_params.bootstrap_room or None
+
         # Create request
         return TokenizedGenerateReqInput(
             rid=grpc_req.request_id,
@@ -425,6 +440,9 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             token_ids_logprob=(
                 list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
             ),
+            bootstrap_host=bootstrap_host,
+            bootstrap_port=bootstrap_port,
+            bootstrap_room=bootstrap_room,
         )

     def _convert_embed_request(
@@ -659,6 +677,16 @@ async def serve_grpc(
 ):
     """Start the standalone gRPC server with integrated scheduler."""

+    # Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
+    # This ensures the bootstrap server is ready when prefill schedulers try to register
+    bootstrap_server = None
+    if server_args.disaggregation_mode == "prefill":
+        bootstrap_server = start_disagg_service(server_args)
+    if bootstrap_server:
+        logger.info(
+            f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
+        )
+
     # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
     logger.info("Launching scheduler process(es)...")
     scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
@@ -682,9 +710,11 @@
     }

     # Create request manager with the correct port args
+    # Note: the bootstrap server was already started above (prefill mode) and is passed in here
     request_manager = GrpcRequestManager(
         server_args=server_args,
         port_args=port_args,
+        bootstrap_server=bootstrap_server,
     )

     # Create gRPC server
@@ -764,79 +794,9 @@ def main():
     mp.set_start_method("spawn", force=True)

     parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
-
-    # Server arguments
-    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
-    parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
-
-    # Model arguments
-    parser.add_argument("--model-path", type=str, required=True, help="Model path")
-    parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
-    parser.add_argument("--context-length", type=int, help="Context length")
-    parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
-    parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
-
-    # Runtime arguments
-    parser.add_argument(
-        "--max-running-requests", type=int, default=2048, help="Max concurrent requests"
-    )
-    parser.add_argument(
-        "--max-total-tokens", type=int, default=1000000, help="Max total tokens"
-    )
-    parser.add_argument(
-        "--max-prefill-tokens", type=int, default=16384, help="Max prefill 
tokens" - ) - parser.add_argument( - "--attention-backend", type=str, default="flashinfer", help="Attention backend" - ) - parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths") - - # Logging - parser.add_argument("--log-level", type=str, default="INFO", help="Logging level") - - # Disaggregation mode arguments - parser.add_argument( - "--disaggregation-mode", - type=str, - default="null", - choices=["null", "prefill", "decode"], - help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated', - ) - parser.add_argument( - "--disaggregation-transfer-backend", - type=str, - default="mooncake", - choices=["mooncake", "nixl", "ascend", "fake"], - help="The backend for disaggregation transfer. Default is mooncake.", - ) - parser.add_argument( - "--disaggregation-bootstrap-port", - type=int, - default=8998, - help="Bootstrap server port on the prefill server. Default is 8998.", - ) - + ServerArgs.add_cli_args(parser) args = parser.parse_args() - - # Convert to ServerArgs with gRPC host and port - server_args = ServerArgs( - model_path=args.model_path, - tokenizer_path=args.tokenizer_path or args.model_path, - context_length=args.context_length, - tp_size=args.tp_size, - dp_size=args.dp_size, - max_running_requests=args.max_running_requests, - max_total_tokens=args.max_total_tokens, - max_prefill_tokens=args.max_prefill_tokens, - attention_backend=args.attention_backend, - lora_paths=args.lora_paths.split(",") if args.lora_paths else None, - log_level=args.log_level, - disaggregation_mode=args.disaggregation_mode, - disaggregation_transfer_backend=args.disaggregation_transfer_backend, - disaggregation_bootstrap_port=args.disaggregation_bootstrap_port, - host=args.host, - port=args.port, - ) + server_args = ServerArgs.from_cli_args(args) # Run server asyncio.run( diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index f354a4b76..570244dc7 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -31,6 +31,18 @@ pub trait Worker: Send + Sync + fmt::Debug { /// Get the worker's connection mode (HTTP or gRPC) fn connection_mode(&self) -> ConnectionMode; + /// Get the bootstrap hostname for PD mode + /// Returns cached hostname parsed from URL at construction time + fn bootstrap_host(&self) -> &str { + &self.metadata().bootstrap_host + } + + /// Get the bootstrap port for PD mode + /// Returns cached port from WorkerType::Prefill + fn bootstrap_port(&self) -> Option { + self.metadata().bootstrap_port + } + /// Check if the worker is currently healthy fn is_healthy(&self) -> bool; @@ -147,21 +159,6 @@ pub trait Worker: Send + Sync + fmt::Debug { true } - // TODO: - Enhanced Worker Discovery - // The Worker trait should handle async discovery of metadata from the worker itself - // rather than having service discovery or other components query /get_server_info. - // This keeps service discovery decoupled from worker-specific APIs. - // - // Proposed additions: - // - async fn discover_metadata(&mut self) -> Result<(), Error> - // Query /get_server_info and populate metadata labels with model_id, priority, cost, etc. 
- // - async fn validate_configuration(&self) -> Result<(), Error> - // Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC) - // - Make worker creation async to allow metadata discovery during initialization - // - // This way service discovery just calls router.add_worker() and the worker - // handles its own metadata discovery internally. - /// Get the model ID this worker serves fn model_id(&self) -> &str { self.metadata() @@ -325,6 +322,10 @@ pub struct WorkerMetadata { pub health_config: HealthConfig, /// API key pub api_key: Option, + /// Cached bootstrap hostname (parsed from URL at construction time) + pub bootstrap_host: String, + /// Cached bootstrap port (from WorkerType::Prefill) + pub bootstrap_port: Option, } /// Basic worker implementation diff --git a/sgl-router/src/core/worker_builder.rs b/sgl-router/src/core/worker_builder.rs index 4e156bb42..69a4047b2 100644 --- a/sgl-router/src/core/worker_builder.rs +++ b/sgl-router/src/core/worker_builder.rs @@ -96,12 +96,29 @@ impl BasicWorkerBuilder { /// Build the BasicWorker instance pub fn build(self) -> BasicWorker { + use std::borrow::Cow; use std::sync::{ atomic::{AtomicBool, AtomicUsize}, Arc, }; use tokio::sync::{Mutex, RwLock}; + let url_to_parse = if self.url.contains("://") { + Cow::from(&self.url) + } else { + Cow::from(format!("http://{}", self.url)) + }; + + let bootstrap_host = match url::Url::parse(&url_to_parse) { + Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(), + Err(_) => "localhost".to_string(), + }; + + let bootstrap_port = match self.worker_type { + WorkerType::Prefill { bootstrap_port } => bootstrap_port, + _ => None, + }; + let metadata = WorkerMetadata { url: self.url.clone(), api_key: self.api_key, @@ -109,6 +126,8 @@ impl BasicWorkerBuilder { connection_mode: self.connection_mode, labels: self.labels, health_config: self.health_config, + bootstrap_host, + bootstrap_port, }; let grpc_client = Arc::new(RwLock::new( diff --git a/sgl-router/src/grpc_client/sglang_scheduler.rs b/sgl-router/src/grpc_client/sglang_scheduler.rs index 4cb082b53..845c217be 100644 --- a/sgl-router/src/grpc_client/sglang_scheduler.rs +++ b/sgl-router/src/grpc_client/sglang_scheduler.rs @@ -342,6 +342,12 @@ impl SglangSchedulerClient { .map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?; } + // Handle n with conversion + if let Some(n) = p.n { + sampling.n = i32::try_from(n) + .map_err(|_| "n must fit into a 32-bit signed integer".to_string())?; + } + // Handle constraints (exactly one allowed) sampling.constraint = Self::build_single_constraint_from_plain(p)?; diff --git a/sgl-router/src/protocols/spec.rs b/sgl-router/src/protocols/spec.rs index fc4b9854b..2f99811fa 100644 --- a/sgl-router/src/protocols/spec.rs +++ b/sgl-router/src/protocols/spec.rs @@ -2,6 +2,11 @@ use serde::{Deserialize, Serialize}; use serde_json::{to_value, Map, Number, Value}; use std::collections::HashMap; +// Default model value when not specified +fn default_model() -> String { + "unknown".to_string() +} + // # Protocol Specifications // // This module contains all protocol definitions for OpenAI and SGLang APIs. @@ -169,6 +174,7 @@ pub struct ChatCompletionRequest { pub messages: Vec, /// ID of the model to use + #[serde(default = "default_model")] pub model: String, /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far diff --git a/sgl-router/src/routers/grpc/mod.rs b/sgl-router/src/routers/grpc/mod.rs index a6a5d8eec..03a1f9ac2 100644 --- a/sgl-router/src/routers/grpc/mod.rs +++ b/sgl-router/src/routers/grpc/mod.rs @@ -1,4 +1,16 @@ //! gRPC router implementations +use crate::grpc_client::proto; +use crate::protocols::spec::StringOrArray; + pub mod pd_router; pub mod router; +pub mod utils; + +/// Processed chat messages ready for gRPC generation +#[derive(Debug)] +pub struct ProcessedMessages { + pub text: String, + pub multimodal_inputs: Option, + pub stop_sequences: Option, +} diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 135260f06..1cadd834b 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -1,24 +1,48 @@ // PD (Prefill-Decode) gRPC Router Implementation use crate::config::types::RetryConfig; -use crate::core::{WorkerRegistry, WorkerType}; -use crate::metrics::RouterMetrics; +use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; +use crate::grpc_client::proto; +use crate::grpc_client::SglangSchedulerClient; use crate::policies::PolicyRegistry; -use crate::reasoning_parser::ReasoningParserFactory; -use crate::routers::RouterTrait; +use crate::protocols::spec::{ + ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionStreamResponse, ChatLogProbs, ChatLogProbsContent, ChatMessageDelta, + ChatStreamChoice, CompletionRequest, EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, + GenerateRequest, InputIds, RerankRequest, ResponsesGetParams, ResponsesRequest, StringOrArray, + Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, TopLogProb, Usage, +}; +use crate::reasoning_parser::{ParserResult, ReasoningParser, ReasoningParserFactory}; +use crate::routers::http::pd_types::generate_room_id; +use crate::routers::{grpc, RouterTrait}; +use crate::server::AppContext; use crate::tokenizer::traits::Tokenizer; -use crate::tool_parser::ToolParserFactory; +use crate::tokenizer::{SequenceDecoderOutput, StopSequenceDecoder}; +use crate::tool_parser::{StreamingParseResult, ToolParser, ToolParserFactory}; use async_trait::async_trait; use axum::{ body::Body, extract::Request, - http::{HeaderMap, StatusCode}, + http::{header, HeaderMap, HeaderValue, StatusCode}, response::{IntoResponse, Response}, + Json, }; +use grpc::utils; +use proto::generate_response::Response::{Chunk, Complete, Error}; +use serde_json::Value; +use std::collections::HashMap; use std::sync::Arc; -use tracing::info; +use std::time::Instant; +use std::time::{SystemTime, UNIX_EPOCH}; +use tokio::sync::mpsc::unbounded_channel; +use tokio::sync::mpsc::UnboundedSender; +use tokio_stream::Stream; +use tokio_stream::StreamExt; +use tracing::{debug, error, warn}; +use uuid::Uuid; /// gRPC PD (Prefill-Decode) router implementation for SGLang +#[derive(Clone)] #[allow(dead_code)] // Fields will be used once implementation is complete pub struct GrpcPDRouter { worker_registry: Arc, @@ -26,7 +50,6 @@ pub struct GrpcPDRouter { tokenizer: Arc, reasoning_parser_factory: ReasoningParserFactory, tool_parser_factory: ToolParserFactory, - dp_aware: bool, api_key: Option, retry_config: RetryConfig, @@ -34,7 +57,7 @@ pub struct GrpcPDRouter { impl GrpcPDRouter { /// Create a new gRPC PD router - pub async fn new(ctx: &Arc) -> Result { + pub async fn new(ctx: &Arc) -> Result { // Get registries from 
context let worker_registry = ctx.worker_registry.clone(); let policy_registry = ctx.policy_registry.clone(); @@ -56,33 +79,6 @@ impl GrpcPDRouter { .ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())? .clone(); - // Get prefill and decode workers from registry - they should have been created by WorkerManager - let prefill_workers = worker_registry.get_workers_filtered( - None, // any model - Some(WorkerType::Prefill { - bootstrap_port: None, - }), - Some(crate::core::ConnectionMode::Grpc { port: None }), - false, // include unhealthy workers during initialization - ); - - let decode_workers = worker_registry.get_workers_filtered( - None, // any model - Some(WorkerType::Decode), - Some(crate::core::ConnectionMode::Grpc { port: None }), - false, // include unhealthy workers during initialization - ); - - // Update metrics - RouterMetrics::set_active_workers(prefill_workers.len() + decode_workers.len()); - info!( - "gRPC PD router found {} prefill and {} decode workers in registry", - prefill_workers.len(), - decode_workers.len() - ); - - // No need for local health checkers - WorkerRegistry handles health checking - Ok(GrpcPDRouter { worker_registry, policy_registry, @@ -94,6 +90,1895 @@ impl GrpcPDRouter { retry_config: ctx.router_config.effective_retry_config(), }) } + + /// Select a prefill-decode worker pair using load balancing policies + async fn select_pd_pair( + &self, + request_text: Option<&str>, + model_id: Option<&str>, + ) -> Result<(Arc, Arc), String> { + let effective_model_id = if !self.dp_aware { None } else { model_id }; + + debug!( + "Selecting PD pair: dp_aware={}, model_id={:?}, effective_model_id={:?}", + self.dp_aware, model_id, effective_model_id + ); + + // Get prefill workers + let prefill_workers = if let Some(model) = effective_model_id { + self.worker_registry + .get_by_model_fast(model) + .into_iter() + .filter(|w| matches!(w.worker_type(), WorkerType::Prefill { .. 
})) + .collect() + } else { + self.worker_registry.get_workers_filtered( + None, + Some(WorkerType::Prefill { + bootstrap_port: None, + }), + Some(ConnectionMode::Grpc { port: None }), + true, // only healthy workers + ) + }; + + // Get decode workers + let decode_workers = if let Some(model) = effective_model_id { + self.worker_registry + .get_by_model_fast(model) + .into_iter() + .filter(|w| matches!(w.worker_type(), WorkerType::Decode)) + .collect() + } else { + self.worker_registry.get_workers_filtered( + None, + Some(WorkerType::Decode), + Some(ConnectionMode::Grpc { port: None }), + true, // only healthy workers + ) + }; + + if prefill_workers.is_empty() { + return Err("No healthy prefill workers available".to_string()); + } + if decode_workers.is_empty() { + return Err("No healthy decode workers available".to_string()); + } + + debug!( + "Found {} prefill workers and {} decode workers", + prefill_workers.len(), + decode_workers.len() + ); + + let prefill_policy = self.policy_registry.get_prefill_policy(); + let decode_policy = self.policy_registry.get_decode_policy(); + + let prefill_idx = prefill_policy + .select_worker(&prefill_workers, request_text) + .ok_or_else(|| "Failed to select prefill worker".to_string())?; + + let decode_idx = decode_policy + .select_worker(&decode_workers, request_text) + .ok_or_else(|| "Failed to select decode worker".to_string())?; + + let prefill = prefill_workers[prefill_idx].clone(); + let decode = decode_workers[decode_idx].clone(); + + debug!( + "Selected PD pair: prefill={}, decode={}", + prefill.url(), + decode.url() + ); + + Ok((prefill, decode)) + } + + /// Main route_generate implementation with PD dual dispatch + async fn route_generate_impl( + &self, + _headers: Option<&HeaderMap>, + body: &GenerateRequest, + model_id: Option<&str>, + ) -> Response { + debug!( + "Processing generate request for model: {:?} (PD mode)", + model_id + ); + + // Step 1: Resolve input (text or input_ids) + let (original_text, token_ids) = match self.resolve_generate_input(body) { + Ok(res) => res, + Err(msg) => { + error!("Invalid generate request: {}", msg); + return (StatusCode::BAD_REQUEST, msg).into_response(); + } + }; + + debug!("Resolved input with {} tokens", token_ids.len()); + + // Step 2: Select prefill-decode worker pair + let (prefill_worker, decode_worker) = match self + .select_pd_pair(original_text.as_deref(), model_id) + .await + { + Ok(pair) => pair, + Err(e) => { + warn!("Failed to select PD worker pair: {}", e); + return (StatusCode::SERVICE_UNAVAILABLE, e).into_response(); + } + }; + + debug!( + "Selected PD pair: prefill={}, decode={}", + prefill_worker.url(), + decode_worker.url() + ); + + // Step 3: Get gRPC clients for both workers + let prefill_client = match utils::get_grpc_client_from_worker(&prefill_worker).await { + Ok(client) => client, + Err(response) => return response, + }; + + let decode_client = match utils::get_grpc_client_from_worker(&decode_worker).await { + Ok(client) => client, + Err(response) => return response, + }; + + // Step 4: Build the gRPC request + let request_id = body + .rid + .clone() + .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4())); + + let mut request = match prefill_client.build_plain_generate_request( + request_id.clone(), + body, + original_text.clone(), + token_ids, + ) { + Ok(req) => req, + Err(e) => { + error!("Failed to build generate request: {}", e); + return (StatusCode::BAD_REQUEST, e).into_response(); + } + }; + + // Step 5: Inject bootstrap metadata + if let Err(e) = 
Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) { + error!("Failed to inject bootstrap metadata: {}", e); + return (StatusCode::INTERNAL_SERVER_ERROR, e).into_response(); + } + + // Step 6: Get weight version for response metadata + let weight_version = decode_worker + .metadata() + .labels + .get("weight_version") + .cloned() + .unwrap_or_else(|| "default".to_string()); + + // Step 7: Handle streaming vs non-streaming + if body.stream { + self.handle_streaming_generate( + prefill_client, + decode_client, + request, + body, + request_id, + weight_version, + ) + .await + } else { + self.handle_non_streaming_generate( + prefill_client, + decode_client, + request, + body, + request_id, + weight_version, + ) + .await + } + } + + /// Inject bootstrap metadata into a protobuf GenerateRequest + fn inject_bootstrap_metadata( + request: &mut proto::GenerateRequest, + prefill_worker: &dyn Worker, + ) -> Result<(), String> { + let hostname = prefill_worker.bootstrap_host(); + let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998); + + let room_id = generate_room_id(); + + // Create DisaggregatedParams + let disagg_params = proto::DisaggregatedParams { + bootstrap_host: hostname.to_string(), + bootstrap_port: bootstrap_port as i32, + bootstrap_room: room_id as i32, + }; + + // Inject metadata + request.disaggregated_params = Some(disagg_params); + + debug!( + "Injected bootstrap metadata: host={}, port={}, room={}", + hostname, bootstrap_port, room_id + ); + + Ok(()) + } + + /// Main route_chat implementation with PD dual dispatch + async fn route_chat_impl( + &self, + _headers: Option<&HeaderMap>, + body: &ChatCompletionRequest, + model_id: Option<&str>, + ) -> Response { + debug!( + "Processing chat completion request for model: {:?} (PD mode)", + model_id + ); + + // Step 1: Filter tools if needed for allowed_tools or specific function + let body_ref = utils::filter_tools_for_request(body); + + // Step 2: Process messages and apply chat template + let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { + Ok(msgs) => msgs, + Err(e) => { + error!("Failed to process chat messages: {}", e); + return (StatusCode::BAD_REQUEST, e.to_string()).into_response(); + } + }; + + // Step 3: Tokenize the processed text + let encoding = match self.tokenizer.encode(&processed_messages.text) { + Ok(encoding) => encoding, + Err(e) => { + error!("Tokenization failed: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Tokenization failed: {}", e), + ) + .into_response(); + } + }; + + // Step 4: Build tool constraints if needed + // body_ref already has filtered tools if needed + let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| { + utils::generate_tool_constraints(tools, &body.tool_choice, &body.model) + }); + + let token_ids = encoding.token_ids().to_vec(); + debug!("Tokenized {} tokens from input", token_ids.len()); + + // Step 5: Select prefill-decode worker pair + let (prefill_worker, decode_worker) = match self + .select_pd_pair(Some(&processed_messages.text), model_id) + .await + { + Ok(pair) => pair, + Err(e) => { + warn!("Failed to select PD worker pair: {}", e); + return (StatusCode::SERVICE_UNAVAILABLE, e).into_response(); + } + }; + + debug!( + "Selected PD pair: prefill={}, decode={}", + prefill_worker.url(), + decode_worker.url() + ); + + // Step 6: Get gRPC clients for both workers + let prefill_client = match utils::get_grpc_client_from_worker(&prefill_worker).await { + Ok(client) => client, + Err(response) 
=> return response, + }; + + let decode_client = match utils::get_grpc_client_from_worker(&decode_worker).await { + Ok(client) => client, + Err(response) => return response, + }; + + // Step 7: Build the base gRPC request + let request_id = format!("chatcmpl-{}", Uuid::new_v4()); + let mut request = match prefill_client.build_generate_request( + request_id.clone(), + &body_ref, + processed_messages.text.clone(), + token_ids, + processed_messages.multimodal_inputs, + tool_call_constraint, + ) { + Ok(request) => request, + Err(e) => { + error!("Failed to build gRPC request: {}", e); + return ( + StatusCode::BAD_REQUEST, + format!("Invalid request parameters: {}", e), + ) + .into_response(); + } + }; + + // Step 8: Inject bootstrap metadata into the request + if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) { + error!("Failed to inject bootstrap metadata: {}", e); + return (StatusCode::INTERNAL_SERVER_ERROR, e).into_response(); + } + + // Step 9: Handle streaming vs non-streaming + if body.stream { + self.handle_streaming_chat(prefill_client, decode_client, request, body) + .await + } else { + self.handle_non_streaming_chat(prefill_client, decode_client, request, body) + .await + } + } + + /// Resolve the generate input into optional original text and token IDs + fn resolve_generate_input( + &self, + request: &GenerateRequest, + ) -> Result<(Option, Vec), String> { + if let Some(text) = &request.text { + let encoding = self + .tokenizer + .encode(text) + .map_err(|e| format!("Tokenization failed: {}", e))?; + return Ok((Some(text.to_string()), encoding.token_ids().to_vec())); + } + + // Handle input_ids - validate and convert + if let Some(input_ids) = &request.input_ids { + return match input_ids { + InputIds::Single(ids) => ids + .iter() + .map(|&id| u32::try_from(id)) + .collect::, _>>() + .map(|converted| (None, converted)) + .map_err(|_| "input_ids must be non-negative".to_string()), + InputIds::Batch(_) => { + Err("Batch input_ids are not supported in PD mode".to_string()) + } + }; + } + + Err("Either `text` or `input_ids` must be provided".to_string()) + } + + /// Submit request and handle streaming response for chat completions (PD mode) + async fn handle_streaming_chat( + &self, + mut prefill_client: SglangSchedulerClient, + mut decode_client: SglangSchedulerClient, + request: proto::GenerateRequest, + original_request: &ChatCompletionRequest, + ) -> Response { + let request_id = request.request_id.clone(); + let model = original_request.model.clone(); + + // Create channel for SSE streaming + let (tx, rx) = unbounded_channel::>(); + + // Send requests in parallel to both prefill and decode workers + debug!("Starting concurrent streaming requests to prefill and decode workers"); + let prefill_request = request.clone(); + let decode_request = request; + + let (prefill_result, decode_result) = tokio::join!( + prefill_client.generate(prefill_request), + decode_client.generate(decode_request) + ); + + // Get prefill stream + let prefill_stream = match prefill_result { + Ok(s) => s, + Err(e) => { + error!("Failed to start prefill generation: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Prefill worker failed to start: {}", e), + ) + .into_response(); + } + }; + + // Get decode stream - this is what we'll process for output + let decode_stream = match decode_result { + Ok(s) => s, + Err(e) => { + error!("Failed to start decode generation: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Decode worker failed to start: 
{}", e), + ) + .into_response(); + } + }; + + let stop_params = ( + original_request.stop.clone(), + original_request.stop_token_ids.clone(), + original_request.skip_special_tokens, + original_request.no_stop_trim, + ); + + // Spawn processing task for both streams + let self_clone = self.clone(); + let original_request_clone = original_request.clone(); + tokio::spawn(async move { + let result = Self::process_dual_streaming_chunks( + &self_clone, + prefill_stream, + decode_stream, + request_id, + model, + stop_params, + original_request_clone, + &tx, + ) + .await; + + if let Err(e) = result { + let error_chunk = format!( + "data: {}\n\n", + serde_json::json!({ + "error": { + "message": e, + "type": "internal_error" + } + }) + ); + let _ = tx.send(Ok(bytes::Bytes::from(error_chunk))); + } + + // Send DONE marker + let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n"))); + }); + + // Create response with SSE headers + let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); + let mut response = Response::new(Body::from_stream(stream)); + *response.status_mut() = StatusCode::OK; + response.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("text/event-stream"), + ); + response + .headers_mut() + .insert("Cache-Control", HeaderValue::from_static("no-cache")); + response + .headers_mut() + .insert("Connection", HeaderValue::from_static("keep-alive")); + response + } + + /// Submit request and handle streaming response for generate endpoint (PD mode) + async fn handle_streaming_generate( + &self, + mut prefill_client: SglangSchedulerClient, + mut decode_client: SglangSchedulerClient, + request: proto::GenerateRequest, + original_request: &GenerateRequest, + request_id: String, + weight_version: String, + ) -> Response { + // Create channel for SSE streaming + let (tx, rx) = unbounded_channel::>(); + + // Send requests in parallel to both prefill and decode workers + debug!("Starting concurrent streaming generate requests to prefill and decode workers"); + let prefill_request = request.clone(); + let decode_request = request; + + let (prefill_result, decode_result) = tokio::join!( + prefill_client.generate(prefill_request), + decode_client.generate(decode_request) + ); + + // Get prefill stream (for input_logprobs if needed) + let prefill_stream = match prefill_result { + Ok(s) => s, + Err(e) => { + error!("Failed to start prefill generation: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Prefill worker failed to start: {}", e), + ) + .into_response(); + } + }; + + // Get decode stream - this is what we'll process for output + let decode_stream = match decode_result { + Ok(s) => s, + Err(e) => { + error!("Failed to start decode generation: {}", e); + return ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Decode worker failed to start: {}", e), + ) + .into_response(); + } + }; + + // Spawn processing task for both streams + let tokenizer = self.tokenizer.clone(); + let return_logprob = original_request.return_logprob; + tokio::spawn(async move { + let result = Self::process_generate_streaming( + tokenizer, + prefill_stream, + decode_stream, + request_id, + weight_version, + return_logprob, + &tx, + ) + .await; + + if let Err(e) = result { + let error_chunk = format!( + "data: {}\n\n", + serde_json::json!({ + "error": { + "message": e, + "type": "internal_error" + } + }) + ); + let _ = tx.send(Ok(bytes::Bytes::from(error_chunk))); + } + + // Send DONE marker + let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n"))); + }); + + // 
Create response with SSE headers + let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); + let mut response = Response::new(Body::from_stream(stream)); + *response.status_mut() = StatusCode::OK; + response.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_static("text/event-stream"), + ); + response + .headers_mut() + .insert("Cache-Control", HeaderValue::from_static("no-cache")); + response + .headers_mut() + .insert("Connection", HeaderValue::from_static("keep-alive")); + response + } + + /// Process generate streaming (simplified - no tool calls or reasoning) + #[allow(clippy::too_many_arguments)] + async fn process_generate_streaming( + tokenizer: Arc, + mut prefill_stream: impl Stream> + Unpin, + mut decode_stream: impl Stream> + Unpin, + request_id: String, + weight_version: String, + include_logprobs: bool, + tx: &UnboundedSender>, + ) -> Result<(), String> { + let start_time = Instant::now(); + + // Phase 1: Collect input_logprobs from prefill stream if requested + // TODO: Store and emit input_logprobs when implementing prompt logprobs in streaming + if include_logprobs { + while let Some(response) = prefill_stream.next().await { + let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?; + match gen_response.response { + Some(Complete(_complete)) => { + // Input logprobs collected but not yet used in streaming + break; + } + Some(Error(error)) => { + return Err(format!("Prefill error: {}", error.message)); + } + _ => continue, + } + } + } + + // Phase 2: Main streaming loop (decode stream) + // Track state per index for n>1 case + let mut accumulated_texts: HashMap = HashMap::new(); + let mut completion_tokens_map: HashMap = HashMap::new(); + let mut current_index: u32 = 0; + + while let Some(response) = decode_stream.next().await { + let gen_response = response.map_err(|e| format!("Decode stream error: {}", e))?; + + match gen_response.response { + Some(Chunk(chunk)) => { + // Use our tracked index instead of chunk.index (PD backend bug workaround) + let index = current_index; + debug!( + "Received chunk with backend_index={}, using_index={}, tokens={:?}", + chunk.index, index, chunk.token_ids + ); + + let completion_tokens = completion_tokens_map.entry(index).or_insert(0); + *completion_tokens += chunk.token_ids.len() as u32; + + let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default(); + + let accumulated_text = accumulated_texts.entry(index).or_default(); + accumulated_text.push_str(&chunk_text); + + let index_id = format!("{}-{}", request_id, index); + + let chunk_response = serde_json::json!({ + "text": accumulated_text.clone(), + "output_ids": chunk.token_ids, + "meta_info": { + "id": index_id, + "finish_reason": null, + "prompt_tokens": chunk.prompt_tokens, + "weight_version": weight_version, + "completion_tokens": *completion_tokens, + "cached_tokens": chunk.cached_tokens + }, + "index": index + }); + + let sse_chunk = format!( + "data: {}\n\n", + serde_json::to_string(&chunk_response).unwrap() + ); + tx.send(Ok(bytes::Bytes::from(sse_chunk))) + .map_err(|_| "Failed to send chunk".to_string())?; + } + Some(Complete(complete)) => { + let index = current_index; + debug!( + "Received Complete with backend_index={}, using_index={}, finish_reason={}", + complete.index, index, complete.finish_reason + ); + let accumulated_text = + accumulated_texts.get(&index).cloned().unwrap_or_default(); + let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0); + let index_id = format!("{}-{}", 
request_id, index); + let e2e_latency = start_time.elapsed().as_secs_f64(); + + // Send final chunk with finish_reason (no new tokens in Complete, they were already sent in Chunks) + let finish_response = serde_json::json!({ + "text": accumulated_text, + "output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(), + "meta_info": { + "id": index_id, + "finish_reason": complete.finish_reason, + "prompt_tokens": complete.prompt_tokens, + "weight_version": weight_version, + "completion_tokens": completion_tokens, + "cached_tokens": complete.cached_tokens, + "e2e_latency": e2e_latency + }, + "index": index + }); + + let sse_chunk = format!( + "data: {}\n\n", + serde_json::to_string(&finish_response).unwrap() + ); + tx.send(Ok(bytes::Bytes::from(sse_chunk))) + .map_err(|_| "Failed to send finish chunk".to_string())?; + + // Move to next completion + current_index += 1; + } + Some(Error(error)) => { + return Err(error.message); + } + None => continue, + } + } + + Ok(()) + } + + /// Process dual streaming chunks (prefill + decode) and send SSE events (PD mode) + #[allow(clippy::too_many_arguments)] + async fn process_dual_streaming_chunks( + router: &GrpcPDRouter, + mut prefill_stream: impl Stream> + Unpin, + mut decode_stream: impl Stream> + Unpin, + request_id: String, + model: String, + stop_params: (Option, Option>, bool, bool), + original_request: ChatCompletionRequest, + tx: &UnboundedSender>, + ) -> Result<(), String> { + // Extract request parameters + let separate_reasoning = original_request.separate_reasoning; + let tool_choice = &original_request.tool_choice; + let tools = &original_request.tools; + let history_tool_calls_count = utils::get_history_tool_calls_count(&original_request); + let stream_options = &original_request.stream_options; + + // Phase 1: Initialize state tracking (per-index for n>1 support) + let mut is_firsts: HashMap = HashMap::new(); + let mut stream_buffers: HashMap = HashMap::new(); + let mut finish_reasons: HashMap = HashMap::new(); + let mut matched_stops: HashMap> = HashMap::new(); + let mut prompt_tokens: HashMap = HashMap::new(); + let mut completion_tokens: HashMap = HashMap::new(); + let mut cached_tokens: HashMap = HashMap::new(); + + // Parser state (lazy initialization per index) + type PooledReasoningParser = Arc>>; + let mut reasoning_parsers: HashMap = HashMap::new(); + + type PooledToolParser = Arc>>; + let mut tool_parsers: HashMap = HashMap::new(); + let mut has_tool_calls: HashMap = HashMap::new(); + + // Create stop decoder + let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params; + let mut stop_decoder = utils::create_stop_decoder( + &router.tokenizer, + stop.as_ref(), + stop_token_ids.as_ref(), + skip_special_tokens, + no_stop_trim, + ); + + let created = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + // Phase 1.5: Collect input_logprobs from prefill stream if requested + // Note: In PD mode, input_logprobs come from prefill worker + // TODO: Store and emit input_logprobs when implementing prompt logprobs in streaming + if original_request.logprobs { + while let Some(response) = prefill_stream.next().await { + let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?; + match gen_response.response { + Some(Complete(_complete)) => { + // Input logprobs collected but not yet used in streaming + // (OpenAI spec doesn't require prompt logprobs in streaming responses) + break; + } + Some(Error(error)) => { + return 
Err(format!("Prefill error: {}", error.message)); + } + _ => continue, + } + } + } + + // Phase 2: Main streaming loop (decode stream) + while let Some(response) = decode_stream.next().await { + let gen_response = response.map_err(|e| format!("Stream error: {}", e))?; + + match gen_response.response { + Some(Chunk(chunk)) => { + let index = chunk.index; + + // Process tokens through stop decoder + let (chunk_text, _should_stop) = + Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids); + + if chunk_text.is_empty() { + continue; + } + + // Process logprobs if present + let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs { + match router.convert_proto_to_openai_logprobs(proto_logprobs) { + Ok(logprobs) => Some(logprobs), + Err(e) => { + warn!("Failed to process logprobs: {}", e); + None + } + } + } else { + None + }; + + // Initialize stream buffer if first time + let stream_buffer = stream_buffers.entry(index).or_default(); + + // Send first chunk with role + if is_firsts.get(&index).copied().unwrap_or(true) { + let first_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }; + tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&first_chunk)))) + .map_err(|_| "Failed to send first chunk".to_string())?; + is_firsts.insert(index, false); + } + + // Calculate delta + let mut delta = chunk_text; + stream_buffer.push_str(&delta); + + // Reasoning content handling + if separate_reasoning { + let (normal_text, reasoning_chunk) = router.process_reasoning_stream( + &delta, + index, + &mut reasoning_parsers, + &request_id, + &model, + created, + ); + if let Some(chunk) = reasoning_chunk { + tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&chunk)))) + .map_err(|_| "Failed to send reasoning chunk".to_string())?; + } + delta = normal_text; + } + + // Tool call handling + let tool_choice_enabled = + !matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None))); + + if tool_choice_enabled && tools.is_some() { + let (should_skip, tool_chunks) = router + .process_tool_calls_stream( + &delta, + index, + &mut tool_parsers, + &mut has_tool_calls, + tools.as_ref().unwrap(), + &request_id, + &model, + created, + history_tool_calls_count, + ) + .await; + + for chunk in tool_chunks { + tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&chunk)))) + .map_err(|_| "Failed to send tool call chunk".to_string())?; + } + + if should_skip { + continue; + } + } + + // Regular content emission + if !delta.is_empty() { + let content_chunk = Self::create_content_chunk( + delta, + index, + &request_id, + &model, + created, + choice_logprobs, + ); + tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk( + &content_chunk, + )))) + .map_err(|_| "Failed to send content chunk".to_string())?; + } + } + Some(Complete(complete)) => { + // Flush any remaining text + if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() { + if !text.is_empty() { + let index = complete.index; + let stream_buffer = stream_buffers.entry(index).or_default(); + stream_buffer.push_str(&text); + + let content_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: 
"chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: Some(text), + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }; + + let sse_chunk = serde_json::to_string(&content_chunk) + .map_err(|e| format!("Failed to serialize content chunk: {}", e))?; + tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send flushed content".to_string())?; + } + } + + // Store metadata + let index = complete.index; + prompt_tokens.insert(index, complete.prompt_tokens as u32); + completion_tokens.insert(index, complete.completion_tokens as u32); + cached_tokens.insert(index, complete.cached_tokens as u32); + finish_reasons.insert(index, complete.finish_reason.clone()); + + // Extract matched_stop + let matched_stop_value = match &complete.matched_stop { + Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { + Some(Value::Number(serde_json::Number::from(*token_id))) + } + Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { + Some(Value::String(stop_str.clone())) + } + None => None, + }; + matched_stops.insert(index, matched_stop_value); + + break; + } + Some(Error(error)) => { + return Err(error.message); + } + None => continue, + } + } + + // Phase 3: Check unstreamed tool args + for (index, parser) in &tool_parsers { + let parser_guard = parser.lock().await; + if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() { + for tool_call_item in unstreamed_items { + let tool_call_delta = ToolCallDelta { + index: tool_call_item.tool_index as u32, + id: None, + tool_type: None, + function: Some(FunctionCallDelta { + name: None, + arguments: if !tool_call_item.parameters.is_empty() { + Some(tool_call_item.parameters) + } else { + None + }, + }), + }; + + let tool_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index: *index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: Some(vec![tool_call_delta]), + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }; + + let sse_chunk = serde_json::to_string(&tool_chunk) + .map_err(|e| format!("Failed to serialize tool chunk: {}", e))?; + tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send unstreamed tool args".to_string())?; + } + } + } + + // Phase 4: Finish reason chunks + for (index, finish_reason) in finish_reasons.iter() { + let final_finish_reason = + if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" { + "tool_calls".to_string() + } else { + finish_reason.clone() + }; + + let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone()); + + let finish_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index: *index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: 
Some(final_finish_reason), + matched_stop: matched_stop_value, + }], + usage: None, + }; + + let sse_chunk = serde_json::to_string(&finish_chunk) + .map_err(|e| format!("Failed to serialize finish chunk: {}", e))?; + tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send finish chunk".to_string())?; + } + + // Phase 5: Usage chunk + if let Some(stream_opts) = stream_options { + if stream_opts.include_usage.unwrap_or(false) { + let total_prompt: u32 = prompt_tokens.values().sum(); + let total_completion: u32 = completion_tokens.values().sum(); + + let usage_chunk = ChatCompletionStreamResponse { + id: request_id.clone(), + object: "chat.completion.chunk".to_string(), + created, + model: model.clone(), + system_fingerprint: None, + choices: vec![], + usage: Some(Usage { + prompt_tokens: total_prompt, + completion_tokens: total_completion, + total_tokens: total_prompt + total_completion, + completion_tokens_details: None, + }), + }; + + let sse_chunk = serde_json::to_string(&usage_chunk) + .map_err(|e| format!("Failed to serialize usage chunk: {}", e))?; + tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk)))) + .map_err(|_| "Failed to send usage chunk".to_string())?; + } + } + + Ok(()) + } + + /// Helper: Process reasoning content in streaming mode + fn process_reasoning_stream( + &self, + delta: &str, + index: u32, + reasoning_parsers: &mut HashMap>>>, + request_id: &str, + model: &str, + created: u64, + ) -> (String, Option) { + // Get or create parser for this index + reasoning_parsers + .entry(index) + .or_insert_with(|| self.reasoning_parser_factory.get_pooled(model)); + + if let Some(pooled_parser) = reasoning_parsers.get(&index) { + let parse_result = { + let mut parser = pooled_parser.lock().unwrap(); + parser.parse_reasoning_streaming_incremental(delta) + }; + + match parse_result { + Ok(ParserResult { + reasoning_text, + normal_text, + }) => { + let chunk = if !reasoning_text.is_empty() { + Some(ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: None, + reasoning_content: Some(reasoning_text), + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }) + } else { + None + }; + return (normal_text, chunk); + } + Err(e) => { + warn!("Reasoning parsing error: {}", e); + } + } + } + + (delta.to_string(), None) + } + + /// Helper: Process tool calls in streaming mode + #[allow(clippy::too_many_arguments)] + async fn process_tool_calls_stream( + &self, + delta: &str, + index: u32, + tool_parsers: &mut HashMap>>>, + has_tool_calls: &mut HashMap, + tools: &[Tool], + request_id: &str, + model: &str, + created: u64, + history_tool_calls_count: usize, + ) -> (bool, Vec) { + let mut chunks = Vec::new(); + + // Get or create parser for this index + tool_parsers + .entry(index) + .or_insert_with(|| self.tool_parser_factory.get_pooled(model)); + + if let Some(pooled_parser) = tool_parsers.get(&index) { + let mut parser = pooled_parser.lock().await; + match parser.parse_incremental(delta, tools).await { + Ok(StreamingParseResult { normal_text, calls }) => { + // Emit normal text if present + if !normal_text.is_empty() { + chunks.push(ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + 
created, + model: model.to_string(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: Some(normal_text), + tool_calls: None, + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }); + } + + // Emit tool call chunks + for tool_call_item in calls { + has_tool_calls.insert(index, true); + + let tool_call_id = if let Some(ref name) = tool_call_item.name { + Some(utils::generate_tool_call_id( + model, + name, + tool_call_item.tool_index, + history_tool_calls_count, + )) + } else { + None + }; + + let tool_call_delta = ToolCallDelta { + index: tool_call_item.tool_index as u32, + id: tool_call_id, + tool_type: if tool_call_item.name.is_some() { + Some("function".to_string()) + } else { + None + }, + function: Some(FunctionCallDelta { + name: tool_call_item.name, + arguments: if !tool_call_item.parameters.is_empty() { + Some(tool_call_item.parameters) + } else { + None + }, + }), + }; + + chunks.push(ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: None, + tool_calls: Some(vec![tool_call_delta]), + reasoning_content: None, + }, + logprobs: None, + finish_reason: None, + matched_stop: None, + }], + usage: None, + }); + } + + // If we emitted chunks, skip regular content + return (!chunks.is_empty(), chunks); + } + Err(e) => { + warn!("Tool call parsing error: {}", e); + } + } + } + + (false, chunks) + } + + /// Helper: Create content chunk + fn create_content_chunk( + content: String, + index: u32, + request_id: &str, + model: &str, + created: u64, + logprobs: Option, + ) -> ChatCompletionStreamResponse { + ChatCompletionStreamResponse { + id: request_id.to_string(), + object: "chat.completion.chunk".to_string(), + created, + model: model.to_string(), + system_fingerprint: None, + choices: vec![ChatStreamChoice { + index, + delta: ChatMessageDelta { + role: Some("assistant".to_string()), + content: Some(content), + tool_calls: None, + reasoning_content: None, + }, + logprobs, + finish_reason: None, + matched_stop: None, + }], + usage: None, + } + } + + /// Helper: Format response as SSE chunk + fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String { + format!( + "data: {}\n\n", + serde_json::to_string(response).unwrap_or_default() + ) + } + + /// Process a chunk of tokens through the stop decoder + fn process_chunk_tokens( + stop_decoder: &mut StopSequenceDecoder, + token_ids: &[u32], + ) -> (String, bool) { + let mut chunk_text = String::new(); + + for &token_id in token_ids { + match stop_decoder.process_token(token_id).unwrap_or_else(|e| { + debug!( + "Error processing token {}: {}. 
Treating as Held.", + token_id, e + ); + SequenceDecoderOutput::Held + }) { + SequenceDecoderOutput::Text(text) => { + chunk_text.push_str(&text); + } + SequenceDecoderOutput::StoppedWithText(text) => { + chunk_text.push_str(&text); + return (chunk_text, true); + } + SequenceDecoderOutput::Stopped => { + return (chunk_text, true); + } + SequenceDecoderOutput::Held => {} + } + } + (chunk_text, false) + } + + /// Submit request and handle non-streaming response for chat completions (PD mode) + async fn handle_non_streaming_chat( + &self, + mut prefill_client: SglangSchedulerClient, + mut decode_client: SglangSchedulerClient, + request: proto::GenerateRequest, + original_request: &ChatCompletionRequest, + ) -> Response { + // Step 1: Create stop decoder + let mut stop_decoder = utils::create_stop_decoder( + &self.tokenizer, + original_request.stop.as_ref(), + original_request.stop_token_ids.as_ref(), + original_request.skip_special_tokens, + original_request.no_stop_trim, + ); + + // Step 2: Send requests in parallel + debug!("Sending concurrent requests to prefill and decode workers"); + let prefill_request = request.clone(); + let decode_request = request; + + let (prefill_result, decode_result) = tokio::join!( + prefill_client.generate(prefill_request), + decode_client.generate(decode_request) + ); + + // Step 3: Process prefill stream in parallel - if it fails, assume decode fails + let prefill_stream = match prefill_result { + Ok(s) => s, + Err(e) => { + error!("Failed to start prefill generation: {}", e); + return utils::internal_error_message(format!( + "Prefill worker failed to start: {}", + e + )); + } + }; + + let decode_stream = match decode_result { + Ok(s) => s, + Err(e) => { + error!("Failed to start decode generation: {}", e); + return utils::internal_error_message(format!( + "Decode worker failed to start: {}", + e + )); + } + }; + + // Collect prefill response (for input_logprobs if requested) + let prefill_responses = + match utils::collect_stream_responses(prefill_stream, "Prefill").await { + Ok(responses) => responses, + Err(error_response) => return error_response, + }; + + // Extract input_logprobs from prefill response if available + let prefill_input_logprobs = prefill_responses + .first() + .and_then(|r| r.input_logprobs.clone()); + + // Step 4: Process decode stream (collect all responses for n>1 support) + let all_responses = match utils::collect_stream_responses(decode_stream, "Decode").await { + Ok(responses) => responses, + Err(error_response) => return error_response, + }; + + if all_responses.is_empty() { + return utils::internal_error_static("No responses from decode worker"); + } + + // Process each response into a ChatChoice + let history_tool_calls_count = utils::get_history_tool_calls_count(original_request); + let mut choices = Vec::new(); + for (index, complete) in all_responses.iter().enumerate() { + // Merge prefill input_logprobs if available and requested + let mut complete_with_logprobs = complete.clone(); + if prefill_input_logprobs.is_some() && original_request.logprobs { + complete_with_logprobs.input_logprobs = prefill_input_logprobs.clone(); + } + + match self + .process_single_choice( + &complete_with_logprobs, + index, + original_request, + &mut stop_decoder, + history_tool_calls_count, + ) + .await + { + Ok(choice) => choices.push(choice), + Err(e) => { + return utils::internal_error_message(format!( + "Failed to process choice {}: {}", + index, e + )); + } + } + } + + // Aggregate usage information from all responses + let 
total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum(); + let total_completion_tokens: u32 = all_responses + .iter() + .map(|r| r.completion_tokens as u32) + .sum(); + + let usage = Usage { + prompt_tokens: total_prompt_tokens, + completion_tokens: total_completion_tokens, + total_tokens: total_prompt_tokens + total_completion_tokens, + completion_tokens_details: None, + }; + + // Build final ChatCompletionResponse + let response = ChatCompletionResponse { + id: format!("chatcmpl-{}", Uuid::new_v4()), + object: "chat.completion".to_string(), + created: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + model: original_request.model.clone(), + choices, + usage: Some(usage), + system_fingerprint: None, + }; + + // Serialize and return JSON response + Json(response).into_response() + } + + /// Submit request and handle non-streaming response for generate endpoint (PD mode) + async fn handle_non_streaming_generate( + &self, + mut prefill_client: SglangSchedulerClient, + mut decode_client: SglangSchedulerClient, + request: proto::GenerateRequest, + original_request: &GenerateRequest, + request_id: String, + weight_version: String, + ) -> Response { + use std::time::Instant; + + let start_time = Instant::now(); + + // Send requests in parallel + debug!("Sending concurrent generate requests to prefill and decode workers"); + let prefill_request = request.clone(); + let decode_request = request; + + let (prefill_result, decode_result) = tokio::join!( + prefill_client.generate(prefill_request), + decode_client.generate(decode_request) + ); + + // Process prefill stream + let prefill_stream = match prefill_result { + Ok(s) => s, + Err(e) => { + error!("Failed to start prefill generation: {}", e); + return utils::internal_error_message(format!( + "Prefill worker failed to start: {}", + e + )); + } + }; + + let decode_stream = match decode_result { + Ok(s) => s, + Err(e) => { + error!("Failed to start decode generation: {}", e); + return utils::internal_error_message(format!( + "Decode worker failed to start: {}", + e + )); + } + }; + + // Collect prefill responses + // TODO add logprob for generate + let _prefill_responses = + match utils::collect_stream_responses(prefill_stream, "Prefill").await { + Ok(responses) => responses, + Err(error_response) => return error_response, + }; + + // Collect decode responses + let decode_responses = match utils::collect_stream_responses(decode_stream, "Decode").await + { + Ok(responses) => responses, + Err(error_response) => return error_response, + }; + + if decode_responses.is_empty() { + return utils::internal_error_static("No completion received from decode worker"); + } + + // Create stop decoder from sampling params + let params = original_request.sampling_params.as_ref(); + let mut stop_decoder = utils::create_stop_decoder( + &self.tokenizer, + params.and_then(|p| p.stop.as_ref()), + params.and_then(|p| p.stop_token_ids.as_ref()), + params.and_then(|p| p.skip_special_tokens).unwrap_or(true), + params.and_then(|p| p.no_stop_trim).unwrap_or(false), + ); + + // Process each completion + let mut result_array = Vec::new(); + for mut complete in decode_responses { + stop_decoder.reset(); + + // Process tokens through stop decoder + let outputs = match stop_decoder.process_tokens(&complete.output_ids) { + Ok(outputs) => outputs, + Err(e) => { + return utils::internal_error_message(format!( + "Failed to process tokens: {}", + e + )) + } + }; + + // Accumulate text with early breaks + let mut 
decoded_text = String::new();
+            for output in outputs {
+                match output {
+                    SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
+                    SequenceDecoderOutput::StoppedWithText(t) => {
+                        decoded_text.push_str(&t);
+                        break;
+                    }
+                    SequenceDecoderOutput::Stopped => break,
+                    SequenceDecoderOutput::Held => {}
+                }
+            }
+
+            // Flush remaining text
+            if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
+                decoded_text.push_str(&t);
+            }
+
+            let output_ids = complete.output_ids.clone();
+
+            // Build base meta_info
+            let mut meta_info = serde_json::json!({
+                "id": request_id.clone(),
+                "finish_reason": complete.finish_reason.clone(),
+                "prompt_tokens": complete.prompt_tokens,
+                "weight_version": weight_version.clone(),
+                "completion_tokens": complete.completion_tokens,
+                "cached_tokens": complete.cached_tokens,
+                "e2e_latency": start_time.elapsed().as_secs_f64(),
+            });
+
+            let meta_obj = meta_info.as_object_mut().unwrap();
+
+            // Add matched_stop if present
+            if let Some(matched) = complete.matched_stop.take() {
+                use proto::generate_complete::MatchedStop;
+                let matched_value = match matched {
+                    MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
+                    MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
+                };
+                meta_obj.insert("matched_stop".to_string(), matched_value);
+            }
+
+            result_array.push(serde_json::json!({
+                "text": decoded_text,
+                "output_ids": output_ids,
+                "meta_info": meta_info,
+            }));
+        }
+
+        Json(result_array).into_response()
+    }
+
+    /// Process a single GenerateComplete response into a ChatChoice
+    async fn process_single_choice(
+        &self,
+        complete: &proto::GenerateComplete,
+        index: usize,
+        original_request: &ChatCompletionRequest,
+        stop_decoder: &mut StopSequenceDecoder,
+        history_tool_calls_count: usize,
+    ) -> Result<ChatChoice, String> {
+        stop_decoder.reset();
+        // Decode tokens
+        let outputs = stop_decoder
+            .process_tokens(&complete.output_ids)
+            .map_err(|e| format!("Failed to process tokens: {}", e))?;
+
+        // Accumulate text with early breaks
+        let mut final_text = String::new();
+        for output in outputs {
+            match output {
+                SequenceDecoderOutput::Text(t) => final_text.push_str(&t),
+                SequenceDecoderOutput::StoppedWithText(t) => {
+                    final_text.push_str(&t);
+                    break;
+                }
+                SequenceDecoderOutput::Stopped => break,
+                SequenceDecoderOutput::Held => {}
+            }
+        }
+
+        // Flush remaining text
+        if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
+            final_text.push_str(&t);
+        }
+
+        // Step 1: Handle reasoning content parsing
+        let mut reasoning_text: Option<String> = None;
+        let mut processed_text = final_text;
+
+        // Check if reasoning parsing is enabled and separate_reasoning is requested
+        if original_request.separate_reasoning {
+            let pooled_parser = self
+                .reasoning_parser_factory
+                .get_pooled(&original_request.model);
+
+            let mut parser = pooled_parser
+                .lock()
+                .map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
+            match parser.detect_and_parse_reasoning(&processed_text) {
+                Ok(result) => {
+                    if !result.reasoning_text.is_empty() {
+                        reasoning_text = Some(result.reasoning_text);
+                    }
+                    processed_text = result.normal_text;
+                }
+                Err(e) => {
+                    return Err(format!("Reasoning parsing error: {}", e));
+                }
+            }
+        }
+
+        // Step 2: Handle tool call parsing
+        let mut tool_calls: Option<Vec<ToolCall>> = None;
+
+        // Check if tool calls should be processed
+        let tool_choice_enabled = !matches!(
+            &original_request.tool_choice,
+            Some(ToolChoice::Value(ToolChoiceValue::None))
+        );
+
+        if tool_choice_enabled && original_request.tools.is_some() {
+            // Check if JSON schema constraint was
used (specific function or required mode) + let used_json_schema = match &original_request.tool_choice { + Some(ToolChoice::Function { .. }) => true, + Some(ToolChoice::Value(ToolChoiceValue::Required)) => true, + Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required", + _ => false, + }; + + if used_json_schema { + (tool_calls, processed_text) = utils::parse_json_schema_response( + &processed_text, + &original_request.tool_choice, + ); + } else { + (tool_calls, processed_text) = self + .parse_tool_calls( + &processed_text, + &original_request.model, + history_tool_calls_count, + ) + .await; + } + } + + // Step 3: Use finish reason directly from proto (already OpenAI-compatible string) + let finish_reason_str = &complete.finish_reason; + + // Override finish reason if we have tool calls + let final_finish_reason_str = if tool_calls.is_some() { + "tool_calls" + } else { + finish_reason_str + }; + + // Extract matched_stop information from proto + let matched_stop = match &complete.matched_stop { + Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { + Some(Value::Number(serde_json::Number::from(*token_id))) + } + Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { + Some(Value::String(stop_str.clone())) + } + None => None, + }; + + // Step 4: Convert output logprobs if present + // Note: complete.input_logprobs exists in proto but is not used for chat completions + // (input logprobs are only used in /v1/completions endpoint with echo=true) + let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs { + match self.convert_proto_to_openai_logprobs(proto_logprobs) { + Ok(logprobs) => Some(logprobs), + Err(e) => { + error!("Failed to convert logprobs: {}", e); + None + } + } + } else { + None + }; + + // Step 5: Build ChatCompletionMessage (proper response message type) + let chat_message = ChatCompletionMessage { + role: "assistant".to_string(), + content: if processed_text.is_empty() { + None + } else { + Some(processed_text) + }, + tool_calls, + reasoning_content: reasoning_text, + }; + + // Step 6: Build ChatChoice + let choice = ChatChoice { + index: index as u32, + message: chat_message, + logprobs, + finish_reason: Some(final_finish_reason_str.to_string()), + matched_stop, + hidden_states: None, + }; + + Ok(choice) + } + + /// Parse tool calls using model-specific parser + async fn parse_tool_calls( + &self, + processed_text: &str, + model: &str, + history_tool_calls_count: usize, + ) -> (Option>, String) { + // Get pooled parser for this model + let pooled_parser = self.tool_parser_factory.get_pooled(model); + + // Check format detection first + let can_parse = { + let parser = pooled_parser.lock().await; + parser.detect_format(processed_text) + // Lock is dropped here + }; + + if !can_parse { + return (None, processed_text.to_string()); + } + + // Lock again for async parsing + let result = { + let parser = pooled_parser.lock().await; + parser.parse_complete(processed_text).await + // Lock is dropped here + }; + + match result { + Ok((normal_text, parsed_tool_calls)) => { + if parsed_tool_calls.is_empty() { + return (None, normal_text); + } + + let spec_tool_calls = parsed_tool_calls + .into_iter() + .enumerate() + .map(|(index, tc)| { + // Generate ID for this tool call + let id = utils::generate_tool_call_id( + model, + &tc.function.name, + index, + history_tool_calls_count, + ); + ToolCall { + id, + tool_type: "function".to_string(), + function: FunctionCallResponse { + name: tc.function.name, + arguments: Some( + 
serde_json::to_string(&tc.function.arguments)
+                                        .unwrap_or_else(|_| "{}".to_string()),
+                                ),
+                            },
+                        }
+                    })
+                    .collect();
+                (Some(spec_tool_calls), normal_text)
+            }
+            Err(e) => {
+                error!("Tool call parsing error: {}", e);
+                (None, processed_text.to_string())
+            }
+        }
+    }
+
+    /// Convert proto LogProbs to OpenAI ChatLogProbs format
+    /// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
+    fn convert_proto_to_openai_logprobs(
+        &self,
+        proto_logprobs: &proto::OutputLogProbs,
+    ) -> Result<ChatLogProbs, String> {
+        let mut content_items = Vec::new();
+
+        // Decode token IDs to text (always with skip_special_tokens=false for logprobs)
+        let token_texts: Vec<String> = proto_logprobs
+            .token_ids
+            .iter()
+            .map(|&token_id| {
+                self.tokenizer
+                    .decode(&[token_id as u32], false)
+                    .unwrap_or_else(|_| format!("<token_{}>", token_id))
+            })
+            .collect();
+
+        // Build ChatLogProbsContent for each token
+        for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() {
+            let token_text = token_texts.get(i).cloned().unwrap_or_default();
+            let bytes = Some(token_text.as_bytes().to_vec());
+
+            // Build top_logprobs for this position
+            let mut top_logprobs = Vec::new();
+            if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) {
+                // Decode top token IDs (always with skip_special_tokens=false)
+                let top_token_texts: Vec<String> = top_logprobs_entry
+                    .token_ids
+                    .iter()
+                    .map(|&tid| {
+                        self.tokenizer
+                            .decode(&[tid as u32], false)
+                            .unwrap_or_else(|_| format!("<token_{}>", tid))
+                    })
+                    .collect();
+
+                for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry
+                    .values
+                    .iter()
+                    .zip(top_logprobs_entry.token_ids.iter())
+                    .enumerate()
+                {
+                    if let Some(top_token_text) = top_token_texts.get(j) {
+                        top_logprobs.push(TopLogProb {
+                            token: top_token_text.clone(),
+                            logprob: top_logprob,
+                            bytes: Some(top_token_text.as_bytes().to_vec()),
+                        });
+                    }
+                }
+            }
+
+            content_items.push(ChatLogProbsContent {
+                token: token_text,
+                logprob,
+                bytes,
+                top_logprobs,
+            });
+        }
+
+        Ok(ChatLogProbs::Detailed {
+            content: (!content_items.is_empty()).then_some(content_items),
+        })
+    }
 }
 
 impl std::fmt::Debug for GrpcPDRouter {
@@ -103,13 +1988,13 @@ impl std::fmt::Debug for GrpcPDRouter {
             Some(WorkerType::Prefill {
                 bootstrap_port: None,
             }),
-            Some(crate::core::ConnectionMode::Grpc { port: None }),
+            Some(ConnectionMode::Grpc { port: None }),
             false,
         );
         let decode_workers = self.worker_registry.get_workers_filtered(
             None,
             Some(WorkerType::Decode),
-            Some(crate::core::ConnectionMode::Grpc { port: None }),
+            Some(ConnectionMode::Grpc { port: None }),
            false,
        );
        f.debug_struct("GrpcPDRouter")
@@ -149,26 +2034,26 @@ impl RouterTrait for GrpcPDRouter {
 
     async fn route_generate(
         &self,
-        _headers: Option<&HeaderMap>,
-        _body: &crate::protocols::spec::GenerateRequest,
-        _model_id: Option<&str>,
+        headers: Option<&HeaderMap>,
+        body: &GenerateRequest,
+        model_id: Option<&str>,
     ) -> Response {
-        (StatusCode::NOT_IMPLEMENTED).into_response()
+        self.route_generate_impl(headers, body, model_id).await
     }
 
     async fn route_chat(
         &self,
-        _headers: Option<&HeaderMap>,
-        _body: &crate::protocols::spec::ChatCompletionRequest,
-        _model_id: Option<&str>,
+        headers: Option<&HeaderMap>,
+        body: &ChatCompletionRequest,
+        model_id: Option<&str>,
     ) -> Response {
-        (StatusCode::NOT_IMPLEMENTED).into_response()
+        self.route_chat_impl(headers, body, model_id).await
     }
 
     async fn route_completion(
         &self,
         _headers: Option<&HeaderMap>,
-        _body: &crate::protocols::spec::CompletionRequest,
+        _body: &CompletionRequest,
         _model_id: Option<&str>,
     ) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() @@ -177,7 +2062,7 @@ impl RouterTrait for GrpcPDRouter { async fn route_responses( &self, _headers: Option<&HeaderMap>, - _body: &crate::protocols::spec::ResponsesRequest, + _body: &ResponsesRequest, _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() @@ -187,7 +2072,7 @@ impl RouterTrait for GrpcPDRouter { &self, _headers: Option<&HeaderMap>, _response_id: &str, - _params: &crate::protocols::spec::ResponsesGetParams, + _params: &ResponsesGetParams, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } @@ -199,7 +2084,7 @@ impl RouterTrait for GrpcPDRouter { async fn route_embeddings( &self, _headers: Option<&HeaderMap>, - _body: &crate::protocols::spec::EmbeddingRequest, + _body: &EmbeddingRequest, _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() @@ -208,7 +2093,7 @@ impl RouterTrait for GrpcPDRouter { async fn route_rerank( &self, _headers: Option<&HeaderMap>, - _body: &crate::protocols::spec::RerankRequest, + _body: &RerankRequest, _model_id: Option<&str>, ) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 0c63ff66e..553e3519e 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -15,45 +15,32 @@ use bytes::Bytes; use std::io; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, warn}; use crate::config::types::RetryConfig; use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; use crate::grpc_client::{proto, SglangSchedulerClient}; -use crate::metrics::RouterMetrics; use crate::policies::PolicyRegistry; -use crate::protocols::spec::ChatMessage; use crate::protocols::spec::{ ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, CompletionRequest, - EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest, RerankRequest, - ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, - ToolChoiceValue, Usage, + ChatCompletionStreamResponse, ChatMessage, ChatMessageDelta, ChatStreamChoice, + CompletionRequest, EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest, + RerankRequest, ResponsesGetParams, ResponsesRequest, StringOrArray, ToolCall, ToolCallDelta, + ToolChoice, ToolChoiceValue, Usage, }; use crate::reasoning_parser::{ParserResult, ReasoningParserFactory}; -use crate::routers::RouterTrait; +use crate::routers::{grpc, RouterTrait}; use crate::server::AppContext; -use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; -use crate::tokenizer::stop::{ - SequenceDecoderOutput, StopSequenceDecoder, StopSequenceDecoderBuilder, -}; +use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder}; use crate::tokenizer::traits::Tokenizer; -use crate::tokenizer::HuggingFaceTokenizer; use crate::tool_parser::{StreamingParseResult, ToolParserFactory}; +use grpc::utils; use proto::generate_response::Response::{Chunk, Complete, Error}; -use serde_json::{json, Map, Value}; +use serde_json::{json, Value}; use std::time::{Instant, SystemTime, UNIX_EPOCH}; use tokio_stream::StreamExt; use uuid::Uuid; -// Data structures for processing -#[derive(Debug)] -pub struct ProcessedMessages { - pub text: String, - pub 
multimodal_inputs: Option, - pub stop_sequences: Option, -} - /// gRPC router implementation for SGLang #[derive(Clone)] #[allow(dead_code)] @@ -91,16 +78,6 @@ impl GrpcRouter { let worker_registry = ctx.worker_registry.clone(); let policy_registry = ctx.policy_registry.clone(); - let workers = worker_registry.get_workers_filtered( - None, - Some(WorkerType::Regular), - Some(ConnectionMode::Grpc { port: None }), - false, - ); - - RouterMetrics::set_active_workers(workers.len()); - info!("gRPC router found {} workers in registry", workers.len()); - Ok(GrpcRouter { worker_registry, policy_registry, @@ -125,56 +102,11 @@ impl GrpcRouter { model_id ); - // Step 1: Select worker (fail fast if no workers available) - let worker = match self.select_worker_for_request(model_id, None) { - Some(w) => w, - None => { - warn!("No available workers for model: {:?}", model_id); - return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); - } - }; + // Step 1: Filter tools if needed for allowed_tools or specific function + let body_ref = utils::filter_tools_for_request(body); - debug!("Selected worker: {}", worker.url()); - - // Step 2: Get gRPC client from worker - let client = match Self::get_grpc_client_from_worker(&worker).await { - Ok(client) => client, - Err(response) => return response, - }; - - // Step 3: Filter tools if needed for allowed_tools or specific function - // Only clone body if we need to modify tools - let mut body_with_filtered_tools; - let body_ref = match &body.tool_choice { - Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => { - body_with_filtered_tools = body.clone(); - let all_tools = body_with_filtered_tools.tools.as_ref().unwrap(); - let allowed_names: std::collections::HashSet<&str> = - allowed.iter().map(|t| t.name.as_str()).collect(); - let filtered_tools: Vec = all_tools - .iter() - .filter(|t| allowed_names.contains(t.function.name.as_str())) - .cloned() - .collect(); - body_with_filtered_tools.tools = Some(filtered_tools); - &body_with_filtered_tools - } - Some(ToolChoice::Function { function, .. 
}) if body.tools.is_some() => { - body_with_filtered_tools = body.clone(); - let all_tools = body_with_filtered_tools.tools.as_ref().unwrap(); - let filtered_tools: Vec = all_tools - .iter() - .filter(|t| t.function.name == function.name) - .cloned() - .collect(); - body_with_filtered_tools.tools = Some(filtered_tools); - &body_with_filtered_tools - } - _ => body, // No filtering needed, use original - }; - - // Step 4: Process messages and apply chat template - let processed_messages = match self.process_chat_messages(body_ref) { + // Step 2: Process messages and apply chat template + let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { Ok(msgs) => msgs, Err(e) => { error!("Failed to process chat messages: {}", e); @@ -182,7 +114,7 @@ impl GrpcRouter { } }; - // Step 5: Tokenize the processed text + // Step 3: Tokenize the processed text let encoding = match self.tokenizer.encode(&processed_messages.text) { Ok(encoding) => encoding, Err(e) => { @@ -198,17 +130,35 @@ impl GrpcRouter { let token_ids = encoding.token_ids().to_vec(); debug!("Tokenized {} tokens from input", token_ids.len()); - // Step 6: Build tool constraints if needed + // Step 4: Build tool constraints if needed // body_ref already has filtered tools if needed let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| { - self.generate_tool_constraints(tools, &body.tool_choice, &body.model) + utils::generate_tool_constraints(tools, &body.tool_choice, &body.model) }); + // Step 5: Select worker + let worker = match self.select_worker_for_request(model_id, Some(&processed_messages.text)) + { + Some(w) => w, + None => { + warn!("No available workers for model: {:?}", model_id); + return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); + } + }; + + debug!("Selected worker: {}", worker.url()); + + // Step 6: Get gRPC client from worker + let client = match utils::get_grpc_client_from_worker(&worker).await { + Ok(client) => client, + Err(response) => return response, + }; + // Step 7: Build the base gRPC request (use body_ref with filtered tools if applicable) let request_id = format!("chatcmpl-{}", Uuid::new_v4()); let request = match client.build_generate_request( request_id, - body_ref, + &body_ref, processed_messages.text.clone(), token_ids, processed_messages.multimodal_inputs, @@ -265,7 +215,7 @@ impl GrpcRouter { debug!("Selected worker: {}", worker.url()); // Step 3: Get gRPC client from worker - let client = match Self::get_grpc_client_from_worker(&worker).await { + let client = match utils::get_grpc_client_from_worker(&worker).await { Ok(client) => client, Err(response) => return response, }; @@ -299,44 +249,12 @@ impl GrpcRouter { // Step 6: Handle streaming vs non-streaming if body.stream { - // TODO: Implement streaming support for generate endpoint - return ( - StatusCode::NOT_IMPLEMENTED, - "Streaming generate over gRPC is not supported yet", - ) - .into_response(); + self.handle_streaming_generate(client, request, body, request_id, weight_version) + .await + } else { + self.handle_non_streaming_generate(client, request, body, request_id, weight_version) + .await } - - self.handle_non_streaming_generate(client, request, body, request_id, weight_version) - .await - } - - /// Get gRPC client from worker, returning appropriate error response on failure - async fn get_grpc_client_from_worker( - worker: &Arc, - ) -> Result { - let client_arc = worker - .get_grpc_client() - .await - .map_err(|e| { - error!("Failed to get gRPC client from worker: 
{}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to get gRPC client: {}", e), - ) - .into_response() - })? - .ok_or_else(|| { - error!("Selected worker is not a gRPC worker"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Selected worker is not configured for gRPC", - ) - .into_response() - })?; - - let client = client_arc.lock().await.clone(); - Ok(client) } /// Select a worker for the request @@ -375,412 +293,6 @@ impl GrpcRouter { Some(available[idx].clone()) } - /// Process chat messages and apply template - fn process_chat_messages( - &self, - request: &ChatCompletionRequest, - ) -> Result { - // Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC - let formatted_text = if let Some(hf_tokenizer) = self - .tokenizer - .as_any() - .downcast_ref::() - { - // Get content format and transform messages accordingly - let content_format = hf_tokenizer.chat_template_content_format(); - let mut transformed_messages = - Self::process_content_format(&request.messages, content_format)?; - - // Process tool call arguments in assistant messages - Self::process_tool_call_arguments(&mut transformed_messages)?; - - // Convert tools to JSON values for template processing - let tools_json: Option> = request - .tools - .as_ref() - .map(|tools| { - tools - .iter() - .map(serde_json::to_value) - .collect::, _>>() - }) - .transpose() - .map_err(|e| format!("Failed to serialize tools: {}", e))?; - - // Build template kwargs, merging reasoning_effort if present - let mut combined_template_kwargs = std::collections::HashMap::new(); - - // Add reasoning_effort if present (like Python does) - if let Some(reasoning_effort) = &request.reasoning_effort { - combined_template_kwargs.insert( - "reasoning_effort".to_string(), - Value::String(reasoning_effort.clone()), - ); - } - - // Add any additional template kwargs from request - if let Some(template_kwargs) = &request.chat_template_kwargs { - for (key, value) in template_kwargs { - combined_template_kwargs.insert(key.clone(), value.clone()); - } - } - - let final_template_kwargs = if combined_template_kwargs.is_empty() { - None - } else { - Some(&combined_template_kwargs) - }; - - let params = ChatTemplateParams { - add_generation_prompt: true, - continue_final_message: request.continue_final_message, - tools: tools_json.as_deref(), - template_kwargs: final_template_kwargs, - ..Default::default() - }; - - // Handle assistant prefix for continue_final_message - let assistant_prefix = if request.continue_final_message - && !transformed_messages.is_empty() - && transformed_messages - .last() - .and_then(|msg| msg.get("role")) - .and_then(|v| v.as_str()) - == Some("assistant") - { - // Pop the last message to handle it separately - let last_msg = transformed_messages.pop().unwrap(); - last_msg - .get("content") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - } else { - None - }; - - // Apply chat template with the (now possibly shorter) list of messages - let rendered = hf_tokenizer - .apply_chat_template(&transformed_messages, params) - .map_err(|e| format!("Failed to apply chat template: {}", e))?; - - // Append assistant prefix if we have one - if let Some(prefix) = assistant_prefix { - format!("{}{}", rendered, prefix) - } else { - rendered - } - } else { - return Err( - "gRPC router requires HuggingFace tokenizer with chat template support".to_string(), - ); - }; - - // Placeholder for multimodal inputs - let multimodal_inputs = None; - - Ok(ProcessedMessages { - text: formatted_text, - multimodal_inputs, - 
stop_sequences: request.stop.clone(), - }) - } - - /// Process messages based on content format for ANY message type - fn process_content_format( - messages: &[ChatMessage], - content_format: ChatTemplateContentFormat, - ) -> Result, String> { - messages - .iter() - .map(|message| { - let mut message_json = serde_json::to_value(message) - .map_err(|e| format!("Failed to serialize message: {}", e))?; - - if let Some(obj) = message_json.as_object_mut() { - if let Some(content_value) = obj.get_mut("content") { - Self::transform_content_field(content_value, content_format); - } - } - - Ok(message_json) - }) - .collect() - } - - /// Transform a single content field based on content format - fn transform_content_field( - content_value: &mut Value, - content_format: ChatTemplateContentFormat, - ) { - let Some(content_array) = content_value.as_array() else { - return; // Not multimodal, keep as-is - }; - - match content_format { - ChatTemplateContentFormat::String => { - // Extract and join text parts only - let text_parts: Vec = content_array - .iter() - .filter_map(|part| { - part.as_object()? - .get("type")? - .as_str() - .filter(|&t| t == "text") - .and_then(|_| part.as_object()?.get("text")?.as_str()) - .map(String::from) - }) - .collect(); - - if !text_parts.is_empty() { - *content_value = Value::String(text_parts.join(" ")); - } - } - ChatTemplateContentFormat::OpenAI => { - // Replace media URLs with simple type placeholders - let processed_parts: Vec = content_array - .iter() - .map(|part| { - part.as_object() - .and_then(|obj| obj.get("type")?.as_str()) - .and_then(|type_str| match type_str { - "image_url" => Some(json!({"type": "image"})), - "video_url" => Some(json!({"type": "video"})), - "audio_url" => Some(json!({"type": "audio"})), - _ => None, - }) - .unwrap_or_else(|| part.clone()) - }) - .collect(); - - *content_value = Value::Array(processed_parts); - } - } - } - - /// Process tool call arguments in messages - /// Per Transformers docs, tool call arguments in assistant messages should be dicts - fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> { - for msg in messages { - // Early return if not assistant message - let role = msg.get("role").and_then(|v| v.as_str()); - if role != Some("assistant") { - continue; - } - - // Early return if no tool_calls - let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut()) - else { - continue; - }; - - // Process each tool call's arguments - for call in tool_calls { - let Some(function) = call.get_mut("function") else { - continue; - }; - let Some(args) = function.get_mut("arguments") else { - continue; - }; - let Some(args_str) = args.as_str() else { - continue; - }; - - // Parse JSON string to object (like Python json.loads) - match serde_json::from_str::(args_str) { - Ok(parsed) => *args = parsed, - Err(e) => { - return Err(format!( - "Failed to parse tool call arguments as JSON: '{}'. Error: {}", - args_str, e - )) - } - } - } - } - Ok(()) - } - - /// Generate tool constraints for structured generation - /// Note: tools should already be filtered if needed (by allowed_tools or specific function) - fn generate_tool_constraints( - &self, - tools: &[Tool], - tool_choice: &Option, - _model: &str, - ) -> Option<(String, String)> { - let choice = tool_choice.as_ref()?; - - match choice { - // Specific function: Return parameters schema directly - // tools should already be filtered to contain only the specific function - ToolChoice::Function { .. 
} => { - if tools.is_empty() { - return None; - } - let tool = &tools[0]; - - // Return the tool's parameters schema directly (not wrapped in array) - let params_schema = serde_json::to_string(&tool.function.parameters).ok()?; - Some(("json_schema".to_string(), params_schema)) - } - - // Required: Array of tool calls with minItems: 1 - ToolChoice::Value(ToolChoiceValue::Required) => { - let schema = self.build_required_array_schema(tools)?; - Some(("json_schema".to_string(), schema)) - } - - // AllowedTools with required mode: tools are already filtered - ToolChoice::AllowedTools { mode, .. } => { - if mode == "required" { - if tools.is_empty() { - return None; - } - let schema = self.build_required_array_schema(tools)?; - Some(("json_schema".to_string(), schema)) - } else { - // "auto" mode - no constraint needed - None - } - } - - // "auto" or "none" - no constraint - _ => None, - } - } - - /// Build JSON schema for required tool calls (array with minItems: 1) - /// Includes $defs consolidation from all tools (matching Python's behavior) - fn build_required_array_schema(&self, tools: &[Tool]) -> Option { - // Build anyOf schemas for each tool - let mut any_of_schemas = Vec::new(); - for tool in tools { - let tool_schema = json!({ - "properties": { - "name": { - "type": "string", - "enum": [tool.function.name] - }, - "parameters": tool.function.parameters - }, - "required": ["name", "parameters"] - }); - any_of_schemas.push(tool_schema); - } - - // Consolidate $defs from all tools (matching Python's _get_tool_schema_defs) - let mut all_defs: HashMap = HashMap::new(); - for tool in tools { - if let Value::Object(params) = &tool.function.parameters { - if let Some(Value::Object(defs)) = params.get("$defs") { - for (def_name, def_schema) in defs { - if let Some(existing) = all_defs.get(def_name) { - // Check for conflicts - if existing != def_schema { - error!( - "Tool definition '{}' has multiple schemas, which is not supported", - def_name - ); - return None; - } - } else { - all_defs.insert(def_name.clone(), def_schema.clone()); - } - } - } - } - } - - // Build the full array schema - let mut array_schema = json!({ - "type": "array", - "minItems": 1, - "items": { - "type": "object", - "anyOf": any_of_schemas - } - }); - - // Add $defs if any were found (matching Python's behavior) - if !all_defs.is_empty() { - if let Value::Object(ref mut schema_obj) = array_schema { - let defs_value = - Value::Object(all_defs.into_iter().collect::>()); - schema_obj.insert("$defs".to_string(), defs_value); - } - } - - serde_json::to_string(&array_schema).ok() - } - - /// Parse tool calls from JSON schema constrained response - fn parse_json_schema_response( - &self, - processed_text: &str, - tool_choice: &Option, - ) -> (Option>, String) { - match tool_choice { - Some(ToolChoice::Function { function, .. }) => { - // Specific function: Parse parameters directly - match serde_json::from_str::(processed_text) { - Ok(params) => { - let tool_call = ToolCall { - id: format!("call_{}", uuid::Uuid::new_v4()), - tool_type: "function".to_string(), - function: FunctionCallResponse { - name: function.name.clone(), - arguments: Some( - serde_json::to_string(¶ms) - .unwrap_or_else(|_| "{}".to_string()), - ), - }, - }; - (Some(vec![tool_call]), String::new()) - } - Err(e) => { - error!("Failed to parse specific function parameters: {}", e); - (None, processed_text.to_string()) - } - } - } - Some(ToolChoice::Value(ToolChoiceValue::Required)) - | Some(ToolChoice::AllowedTools { .. 
}) => { - // Required mode: Parse array of tool calls - match serde_json::from_str::>(processed_text) { - Ok(parsed_array) => { - let spec_tool_calls: Vec = parsed_array - .into_iter() - .enumerate() - .filter_map(|(i, item)| { - let obj = item.as_object()?; - let name = obj.get("name")?.as_str()?.to_string(); - let parameters = obj.get("parameters")?; - - Some(ToolCall { - id: format!("call_{}_{}", i, uuid::Uuid::new_v4()), - tool_type: "function".to_string(), - function: FunctionCallResponse { - name, - arguments: Some( - serde_json::to_string(parameters) - .unwrap_or_else(|_| "{}".to_string()), - ), - }, - }) - }) - .collect(); - (Some(spec_tool_calls), String::new()) - } - Err(e) => { - error!("Failed to parse required tool call array: {}", e); - (None, processed_text.to_string()) - } - } - } - _ => (None, processed_text.to_string()), - } - } - /// Parse tool calls using model-specific parser async fn parse_tool_calls( &self, @@ -895,48 +407,6 @@ impl GrpcRouter { (StatusCode::INTERNAL_SERVER_ERROR, message).into_response() } - /// Create a StopSequenceDecoder from stop parameters - fn create_stop_decoder( - &self, - stop: Option<&StringOrArray>, - stop_token_ids: Option<&Vec>, - skip_special_tokens: bool, - no_stop_trim: bool, - ) -> StopSequenceDecoder { - // Extract stop sequences - let stop_sequences: Vec = match stop { - Some(StringOrArray::String(s)) => vec![s.clone()], - Some(StringOrArray::Array(arr)) => arr.clone(), - None => vec![], - }; - - // Build stop sequence decoder - let mut builder = StopSequenceDecoderBuilder::new(self.tokenizer.clone()) - .skip_special_tokens(skip_special_tokens); - - // Add stop sequences (visible if no_stop_trim is true, hidden otherwise) - for seq in stop_sequences { - builder = if no_stop_trim { - builder.visible_stop_sequence(seq) - } else { - builder.stop_sequence(seq) - }; - } - - // Add stop token IDs (visible if no_stop_trim is true, hidden otherwise) - if let Some(token_ids) = stop_token_ids { - for &token_id in token_ids { - builder = if no_stop_trim { - builder.visible_stop_token(token_id) - } else { - builder.stop_token(token_id) - }; - } - } - - builder.build() - } - /// Count the number of tool calls in the request message history /// This is used for KimiK2 format which needs globally unique indices fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize { @@ -1354,7 +824,8 @@ impl GrpcRouter { // Create stop decoder let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params; - let mut stop_decoder = router.create_stop_decoder( + let mut stop_decoder = utils::create_stop_decoder( + &router.tokenizer, stop.as_ref(), stop_token_ids.as_ref(), skip_special_tokens, @@ -1678,7 +1149,8 @@ impl GrpcRouter { request: proto::GenerateRequest, original_request: &ChatCompletionRequest, ) -> Response { - let mut stop_decoder = self.create_stop_decoder( + let mut stop_decoder = utils::create_stop_decoder( + &self.tokenizer, original_request.stop.as_ref(), original_request.stop_token_ids.as_ref(), original_request.skip_special_tokens, @@ -1686,42 +1158,17 @@ impl GrpcRouter { ); // Start generation - let mut stream = match client.generate(request).await { + let stream = match client.generate(request).await { Ok(s) => s, Err(e) => { return Self::internal_error_message(format!("Failed to start generation: {}", e)) } }; - // Collect all responses (for n>1 support) - let mut all_responses = Vec::new(); - while let Some(response) = stream.next().await { - match response { - Ok(gen_response) => match 
gen_response.response { - Some(Complete(complete)) => { - all_responses.push(complete); - } - Some(Error(err)) => { - return Self::internal_error_message(format!( - "Generation failed: {}", - err.message - )); - } - Some(Chunk(_)) => { - return Self::internal_error_static( - "Unexpected chunk response for non-streaming request", - ) - } - None => return Self::internal_error_static("Empty response from server"), - }, - Err(e) => { - return Self::internal_error_message(format!( - "Failed to get GenerateResponse: {}", - e - )) - } - } - } + let all_responses = match utils::collect_stream_responses(stream, "Regular").await { + Ok(responses) => responses, + Err(err_response) => return err_response, + }; if all_responses.is_empty() { return Self::internal_error_static("No responses from server"); @@ -1793,115 +1240,262 @@ impl GrpcRouter { ) -> Response { let start_time = Instant::now(); - let mut stream = match client.generate(request).await { + let stream = match client.generate(request).await { Ok(stream) => stream, Err(e) => { return Self::internal_error_message(format!("Failed to start generation: {}", e)) } }; - let mut final_completion: Option = None; - - while let Some(result) = stream.next().await { - match result { - Ok(gen_response) => match gen_response.response { - Some(Complete(complete)) => { - final_completion = Some(complete); - break; - } - Some(Error(err)) => { - return Self::internal_error_message(format!( - "Generation failed: {}", - err.message - )); - } - Some(Chunk(_)) | None => continue, - }, - Err(e) => { - return Self::internal_error_message(format!( - "Failed to receive generate response: {}", - e - )) - } - } - } - - let mut complete = match final_completion { - Some(c) => c, - None => { - return Self::internal_error_static("No completion received from scheduler"); - } + // Collect all responses using utils helper + let responses = match utils::collect_stream_responses(stream, "Generate").await { + Ok(responses) => responses, + Err(error_response) => return error_response, }; + if responses.is_empty() { + return Self::internal_error_static("No completion received from scheduler"); + } + // Create stop decoder from sampling params let params = original_request.sampling_params.as_ref(); - let mut stop_decoder = self.create_stop_decoder( + let mut stop_decoder = utils::create_stop_decoder( + &self.tokenizer, params.and_then(|p| p.stop.as_ref()), params.and_then(|p| p.stop_token_ids.as_ref()), params.and_then(|p| p.skip_special_tokens).unwrap_or(true), params.and_then(|p| p.no_stop_trim).unwrap_or(false), ); - // Process tokens through stop decoder - let outputs = match stop_decoder.process_tokens(&complete.output_ids) { - Ok(outputs) => outputs, + // Process each completion + let mut result_array = Vec::new(); + for mut complete in responses { + stop_decoder.reset(); + + // Process tokens through stop decoder + let outputs = match stop_decoder.process_tokens(&complete.output_ids) { + Ok(outputs) => outputs, + Err(e) => { + return Self::internal_error_message(format!("Failed to process tokens: {}", e)) + } + }; + + // Accumulate text with early breaks + let mut decoded_text = String::new(); + for output in outputs { + match output { + SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t), + SequenceDecoderOutput::StoppedWithText(t) => { + decoded_text.push_str(&t); + break; + } + SequenceDecoderOutput::Stopped => break, + SequenceDecoderOutput::Held => {} + } + } + + // Flush remaining text + if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { + 
decoded_text.push_str(&t); + } + + let output_ids = std::mem::take(&mut complete.output_ids); + let finish_reason = std::mem::take(&mut complete.finish_reason); + + // Build base meta_info using json! macro + let mut meta_info = json!({ + "id": request_id.clone(), + "finish_reason": finish_reason, + "prompt_tokens": complete.prompt_tokens, + "weight_version": weight_version.clone(), + "completion_tokens": complete.completion_tokens, + "cached_tokens": complete.cached_tokens, + "e2e_latency": start_time.elapsed().as_secs_f64(), + }); + + let meta_obj = meta_info.as_object_mut().unwrap(); + + // Add matched_stop if present + if let Some(matched) = complete.matched_stop.take() { + use proto::generate_complete::MatchedStop; + let matched_value = match matched { + MatchedStop::MatchedTokenId(id) => json!(id), + MatchedStop::MatchedStopStr(s) => json!(s), + }; + meta_obj.insert("matched_stop".to_string(), matched_value); + } + + result_array.push(json!({ + "text": decoded_text, + "output_ids": output_ids, + "meta_info": meta_info, + })); + } + + Json(result_array).into_response() + } + + /// Submit request and handle streaming response for the `/generate` endpoint + async fn handle_streaming_generate( + &self, + mut client: SglangSchedulerClient, + request: proto::GenerateRequest, + original_request: &GenerateRequest, + request_id: String, + weight_version: String, + ) -> Response { + let tokenizer = self.tokenizer.clone(); + let return_logprob = original_request.return_logprob; + + // Create channel for SSE streaming + let (tx, rx) = + tokio::sync::mpsc::unbounded_channel::>(); + + // Start the stream + let stream = match client.generate(request).await { + Ok(stream) => stream, Err(e) => { - return Self::internal_error_message(format!("Failed to process tokens: {}", e)) + return Self::internal_error_message(format!("Failed to start generation: {}", e)) } }; - // Accumulate text with early breaks - let mut decoded_text = String::new(); - for output in outputs { - match output { - SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t), - SequenceDecoderOutput::StoppedWithText(t) => { - decoded_text.push_str(&t); - break; + // Spawn async task to process stream + tokio::spawn(async move { + let result = Self::process_generate_streaming( + tokenizer, + stream, + request_id, + weight_version, + return_logprob, + &tx, + ) + .await; + + if let Err(e) = result { + let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e); + let _ = tx.send(Ok(bytes::Bytes::from(error_chunk))); + } + + // Send [DONE] marker + let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n"))); + }); + + // Create SSE response stream + let body_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx); + Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "text/event-stream") + .header("Cache-Control", "no-cache") + .header("Connection", "keep-alive") + .body(axum::body::Body::from_stream(body_stream)) + .unwrap() + } + + /// Process streaming chunks for generate endpoint + async fn process_generate_streaming( + tokenizer: Arc, + mut stream: impl tokio_stream::Stream> + + Unpin, + request_id: String, + weight_version: String, + _include_logprobs: bool, + tx: &tokio::sync::mpsc::UnboundedSender>, + ) -> Result<(), String> { + use proto::generate_response::Response::{Chunk, Complete, Error}; + use std::time::Instant; + use tokio_stream::StreamExt; + + let start_time = Instant::now(); + + // Track state per index for n>1 case + use std::collections::HashMap; + let mut accumulated_texts: 
HashMap<u32, String> = HashMap::new();
+        let mut completion_tokens_map: HashMap<u32, u32> = HashMap::new();
+
+        while let Some(response) = stream.next().await {
+            let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
+
+            match gen_response.response {
+                Some(Chunk(chunk)) => {
+                    let index = chunk.index;
+
+                    // Update completion tokens for this index
+                    let completion_tokens = completion_tokens_map.entry(index).or_insert(0);
+                    *completion_tokens += chunk.token_ids.len() as u32;
+
+                    // Decode tokens to text (skip_special_tokens=true to handle newlines correctly)
+                    let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default();
+
+                    // Accumulate text for this index
+                    let accumulated_text = accumulated_texts.entry(index).or_default();
+                    accumulated_text.push_str(&chunk_text);
+
+                    // Generate unique ID per index
+                    let index_id = format!("{}-{}", request_id, index);
+
+                    // Build streaming response chunk (SGLang format)
+                    let chunk_response = serde_json::json!({
+                        "text": accumulated_text.clone(),
+                        "output_ids": chunk.token_ids,
+                        "meta_info": {
+                            "id": index_id,
+                            "finish_reason": null,
+                            "prompt_tokens": chunk.prompt_tokens,
+                            "weight_version": weight_version,
+                            "completion_tokens": *completion_tokens,
+                            "cached_tokens": chunk.cached_tokens
+                        },
+                        "index": index
+                    });
+
+                    let sse_chunk = format!(
+                        "data: {}\n\n",
+                        serde_json::to_string(&chunk_response).unwrap()
+                    );
+                    tx.send(Ok(bytes::Bytes::from(sse_chunk)))
+                        .map_err(|_| "Failed to send chunk".to_string())?;
+                }
+                Some(Complete(complete)) => {
+                    let index = complete.index;
+                    let accumulated_text =
+                        accumulated_texts.get(&index).cloned().unwrap_or_default();
+                    let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0);
+                    let index_id = format!("{}-{}", request_id, index);
+                    let e2e_latency = start_time.elapsed().as_secs_f64();
+
+                    // Send final chunk with finish_reason (no new tokens in Complete, they were already sent in Chunks)
+                    let finish_response = serde_json::json!({
+                        "text": accumulated_text,
+                        "output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(),
+                        "meta_info": {
+                            "id": index_id,
+                            "finish_reason": complete.finish_reason,
+                            "prompt_tokens": complete.prompt_tokens,
+                            "weight_version": weight_version,
+                            "completion_tokens": completion_tokens,
+                            "cached_tokens": complete.cached_tokens,
+                            "e2e_latency": e2e_latency
+                        },
+                        "index": index
+                    });
+
+                    let sse_chunk = format!(
+                        "data: {}\n\n",
+                        serde_json::to_string(&finish_response).unwrap()
+                    );
+                    tx.send(Ok(bytes::Bytes::from(sse_chunk)))
+                        .map_err(|_| "Failed to send finish chunk".to_string())?;
+
+                    // Continue to process all completions if n>1
+                }
+                Some(Error(error)) => {
+                    return Err(error.message);
+                }
+                None => continue,
+            }
+        }
-
-        // Flush remaining text
-        if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
-            decoded_text.push_str(&t);
-        }
-
-        let output_ids = std::mem::take(&mut complete.output_ids);
-        let finish_reason = std::mem::take(&mut complete.finish_reason);
-
-        // Build base meta_info using json!
macro - let mut meta_info = json!({ - "finish_reason": finish_reason, - "prompt_tokens": complete.prompt_tokens, - "completion_tokens": complete.completion_tokens, - "cached_tokens": complete.cached_tokens, - "id": request_id, - "weight_version": weight_version, - "e2e_latency": start_time.elapsed().as_secs_f64(), - }); - - let meta_obj = meta_info.as_object_mut().unwrap(); - - // Add matched_stop if present - if let Some(matched) = complete.matched_stop.take() { - use proto::generate_complete::MatchedStop; - let matched_value = match matched { - MatchedStop::MatchedTokenId(id) => json!(id), - MatchedStop::MatchedStopStr(s) => json!(s), - }; - meta_obj.insert("matched_stop".to_string(), matched_value); - } - - let response_body = json!({ - "text": decoded_text, - "output_ids": output_ids, - "meta_info": meta_info, - }); - - Json(response_body).into_response() + Ok(()) } /// Convert proto LogProbs to OpenAI ChatLogProbs format @@ -2036,28 +1630,28 @@ impl GrpcRouter { } // Step 2: Handle tool call parsing - let mut tool_calls: Option> = None; + let mut tool_calls: Option> = None; // Check if tool calls should be processed let tool_choice_enabled = !matches!( &original_request.tool_choice, - Some(ToolChoice::Value( - crate::protocols::spec::ToolChoiceValue::None - )) + Some(ToolChoice::Value(ToolChoiceValue::None)) ); if tool_choice_enabled && original_request.tools.is_some() { // Check if JSON schema constraint was used (specific function or required mode) let used_json_schema = match &original_request.tool_choice { Some(ToolChoice::Function { .. }) => true, - Some(ToolChoice::Value(crate::protocols::spec::ToolChoiceValue::Required)) => true, + Some(ToolChoice::Value(ToolChoiceValue::Required)) => true, Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required", _ => false, }; if used_json_schema { - (tool_calls, processed_text) = - self.parse_json_schema_response(&processed_text, &original_request.tool_choice); + (tool_calls, processed_text) = utils::parse_json_schema_response( + &processed_text, + &original_request.tool_choice, + ); } else { (tool_calls, processed_text) = self .parse_tool_calls( @@ -2081,11 +1675,11 @@ impl GrpcRouter { // Extract matched_stop information from proto let matched_stop = match &complete.matched_stop { - Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => Some( - serde_json::Value::Number(serde_json::Number::from(*token_id)), - ), + Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => { + Some(Value::Number(serde_json::Number::from(*token_id))) + } Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { - Some(serde_json::Value::String(stop_str.clone())) + Some(Value::String(stop_str.clone())) } None => None, }; @@ -2239,240 +1833,3 @@ impl RouterTrait for GrpcRouter { "grpc" } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::protocols::spec::{ChatMessage, ContentPart, ImageUrl, UserMessageContent}; - use crate::tokenizer::chat_template::ChatTemplateContentFormat; - use serde_json::json; - - #[test] - fn test_transform_messages_string_format() { - let messages = vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Parts(vec![ - ContentPart::Text { - text: "Hello".to_string(), - }, - ContentPart::ImageUrl { - image_url: ImageUrl { - url: "https://example.com/image.jpg".to_string(), - detail: None, - }, - }, - ContentPart::Text { - text: "World".to_string(), - }, - ]), - name: None, - }]; - - let result = - GrpcRouter::process_content_format(&messages, 
ChatTemplateContentFormat::String) - .unwrap(); - - assert_eq!(result.len(), 1); - let transformed_message = &result[0]; - - // Should flatten multimodal content to text only - assert_eq!( - transformed_message["content"].as_str().unwrap(), - "Hello World" - ); - assert_eq!(transformed_message["role"].as_str().unwrap(), "user"); - } - - #[test] - fn test_transform_messages_openai_format() { - let messages = vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Parts(vec![ - ContentPart::Text { - text: "Describe this image:".to_string(), - }, - ContentPart::ImageUrl { - image_url: ImageUrl { - url: "https://example.com/image.jpg".to_string(), - detail: Some("high".to_string()), - }, - }, - ]), - name: None, - }]; - - let result = - GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI) - .unwrap(); - - assert_eq!(result.len(), 1); - let transformed_message = &result[0]; - - // Should replace media URLs with simple type placeholders - let content_array = transformed_message["content"].as_array().unwrap(); - assert_eq!(content_array.len(), 2); - - // Text part should remain unchanged - assert_eq!(content_array[0]["type"], "text"); - assert_eq!(content_array[0]["text"], "Describe this image:"); - - // Image part should be replaced with simple type placeholder - assert_eq!(content_array[1], json!({"type": "image"})); - } - - #[test] - fn test_transform_messages_simple_string_content() { - let messages = vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Simple text message".to_string()), - name: None, - }]; - - let result = - GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) - .unwrap(); - - assert_eq!(result.len(), 1); - let transformed_message = &result[0]; - - // Simple string content should remain unchanged - assert_eq!( - transformed_message["content"].as_str().unwrap(), - "Simple text message" - ); - } - - #[test] - fn test_transform_messages_assistant_message() { - let messages = vec![ChatMessage::Assistant { - role: "assistant".to_string(), - content: Some("Assistant response".to_string()), - name: None, - tool_calls: None, - reasoning_content: None, - }]; - - let result = - GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) - .unwrap(); - - assert_eq!(result.len(), 1); - let transformed_message = &result[0]; - - assert_eq!(transformed_message["role"].as_str().unwrap(), "assistant"); - assert_eq!( - transformed_message["content"].as_str().unwrap(), - "Assistant response" - ); - } - - #[test] - fn test_transform_messages_multiple_messages() { - let messages = vec![ - ChatMessage::System { - role: "system".to_string(), - content: "System prompt".to_string(), - name: None, - }, - ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Parts(vec![ - ContentPart::Text { - text: "User message".to_string(), - }, - ContentPart::ImageUrl { - image_url: ImageUrl { - url: "https://example.com/image.jpg".to_string(), - detail: None, - }, - }, - ]), - name: None, - }, - ]; - - let result = - GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) - .unwrap(); - - assert_eq!(result.len(), 2); - - // System message should remain unchanged - assert_eq!(result[0]["role"].as_str().unwrap(), "system"); - assert_eq!(result[0]["content"].as_str().unwrap(), "System prompt"); - - // User message should be flattened to text only - assert_eq!(result[1]["role"].as_str().unwrap(), "user"); - 
assert_eq!(result[1]["content"].as_str().unwrap(), "User message"); - } - - #[test] - fn test_transform_messages_empty_text_parts() { - let messages = vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Parts(vec![ContentPart::ImageUrl { - image_url: ImageUrl { - url: "https://example.com/image.jpg".to_string(), - detail: None, - }, - }]), - name: None, - }]; - - let result = - GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) - .unwrap(); - - assert_eq!(result.len(), 1); - let transformed_message = &result[0]; - - // Should keep original multimodal content when no text parts exist - assert!(transformed_message["content"].is_array()); - } - - #[test] - fn test_transform_messages_mixed_content_types() { - let messages = vec![ - ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Plain text".to_string()), - name: None, - }, - ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Parts(vec![ - ContentPart::Text { - text: "With image".to_string(), - }, - ContentPart::ImageUrl { - image_url: ImageUrl { - url: "https://example.com/image.jpg".to_string(), - detail: Some("low".to_string()), - }, - }, - ]), - name: None, - }, - ]; - - let result_string = - GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String) - .unwrap(); - - assert_eq!(result_string.len(), 2); - assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text"); - assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image"); - - let result_openai = - GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI) - .unwrap(); - - assert_eq!(result_openai.len(), 2); - assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text"); - - let content_array = result_openai[1]["content"].as_array().unwrap(); - assert_eq!(content_array.len(), 2); - assert_eq!(content_array[0]["type"], "text"); - assert_eq!(content_array[1], json!({"type": "image"})); - } -} diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs new file mode 100644 index 000000000..3f97c3ed3 --- /dev/null +++ b/sgl-router/src/routers/grpc/utils.rs @@ -0,0 +1,843 @@ +//! Shared utilities for gRPC routers + +use super::ProcessedMessages; +use crate::core::Worker; +use crate::grpc_client::{proto, SglangSchedulerClient}; +use crate::protocols::spec::{ + ChatCompletionRequest, ChatMessage, FunctionCallResponse, StringOrArray, Tool, ToolCall, + ToolChoice, ToolChoiceValue, +}; +use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; +use crate::tokenizer::traits::Tokenizer; +use crate::tokenizer::HuggingFaceTokenizer; +pub use crate::tokenizer::StopSequenceDecoder; +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, +}; +use futures::StreamExt; +use serde_json::{json, Map, Value}; +use std::collections::HashMap; +use std::sync::Arc; +use tonic::codec::Streaming; +use tracing::{debug, error}; +use uuid::Uuid; + +/// Get gRPC client from worker, returning appropriate error response on failure +pub async fn get_grpc_client_from_worker( + worker: &Arc, +) -> Result { + let client_arc = worker + .get_grpc_client() + .await + .map_err(|e| { + error!("Failed to get gRPC client from worker: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to get gRPC client: {}", e), + ) + .into_response() + })? 
+        .ok_or_else(|| {
+            error!("Selected worker is not a gRPC worker");
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                "Selected worker is not configured for gRPC",
+            )
+                .into_response()
+        })?;
+
+    let client = client_arc.lock().await.clone();
+    Ok(client)
+}
+
+/// Process tool call arguments in messages
+/// Per Transformers docs, tool call arguments in assistant messages should be dicts
+pub fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> {
+    for msg in messages {
+        // Early return if not assistant message
+        let role = msg.get("role").and_then(|v| v.as_str());
+        if role != Some("assistant") {
+            continue;
+        }
+
+        // Early return if no tool_calls
+        let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut()) else {
+            continue;
+        };
+
+        // Process each tool call's arguments
+        for call in tool_calls {
+            let Some(function) = call.get_mut("function") else {
+                continue;
+            };
+            let Some(args) = function.get_mut("arguments") else {
+                continue;
+            };
+            let Some(args_str) = args.as_str() else {
+                continue;
+            };
+
+            // Parse JSON string to object (like Python json.loads)
+            match serde_json::from_str::<Value>(args_str) {
+                Ok(parsed) => *args = parsed,
+                Err(e) => {
+                    return Err(format!(
+                        "Failed to parse tool call arguments as JSON: '{}'. Error: {}",
+                        args_str, e
+                    ))
+                }
+            }
+        }
+    }
+    Ok(())
+}
+
+/// Process messages based on content format for ANY message type
+pub fn process_content_format(
+    messages: &[ChatMessage],
+    content_format: ChatTemplateContentFormat,
+) -> Result<Vec<Value>, String> {
+    messages
+        .iter()
+        .map(|message| {
+            let mut message_json = serde_json::to_value(message)
+                .map_err(|e| format!("Failed to serialize message: {}", e))?;
+
+            if let Some(obj) = message_json.as_object_mut() {
+                if let Some(content_value) = obj.get_mut("content") {
+                    transform_content_field(content_value, content_format);
+                }
+            }
+
+            Ok(message_json)
+        })
+        .collect()
+}
+
+/// Transform a single content field based on content format
+pub fn transform_content_field(
+    content_value: &mut Value,
+    content_format: ChatTemplateContentFormat,
+) {
+    let Some(content_array) = content_value.as_array() else {
+        return; // Not multimodal, keep as-is
+    };
+
+    match content_format {
+        ChatTemplateContentFormat::String => {
+            // Extract and join text parts only
+            let text_parts: Vec<String> = content_array
+                .iter()
+                .filter_map(|part| {
+                    part.as_object()?
+                        .get("type")?
+                        .as_str()
+                        .filter(|&t| t == "text")
+                        .and_then(|_| part.as_object()?.get("text")?.as_str())
+                        .map(String::from)
+                })
+                .collect();
+
+            if !text_parts.is_empty() {
+                *content_value = Value::String(text_parts.join(" "));
+            }
+        }
+        ChatTemplateContentFormat::OpenAI => {
+            // Replace media URLs with simple type placeholders
+            let processed_parts: Vec<Value> = content_array
+                .iter()
+                .map(|part| {
+                    part.as_object()
+                        .and_then(|obj| obj.get("type")?.as_str())
+                        .and_then(|type_str| match type_str {
+                            "image_url" => Some(json!({"type": "image"})),
+                            "video_url" => Some(json!({"type": "video"})),
+                            "audio_url" => Some(json!({"type": "audio"})),
+                            _ => None,
+                        })
+                        .unwrap_or_else(|| part.clone())
+                })
+                .collect();
+
+            *content_value = Value::Array(processed_parts);
+        }
+    }
+}
+
+/// Generate tool constraints for structured generation
+/// Note: tools should already be filtered if needed (by allowed_tools or specific function)
+pub fn generate_tool_constraints(
+    tools: &[Tool],
+    tool_choice: &Option<ToolChoice>,
+    _model: &str,
+) -> Option<(String, String)> {
+    let choice = tool_choice.as_ref()?;
+
+    match choice {
+        // Specific function: Return parameters schema directly
+        // tools should already be filtered to contain only the specific function
+        ToolChoice::Function { .. } => {
+            if tools.is_empty() {
+                return None;
+            }
+            let tool = &tools[0];
+
+            // Return the tool's parameters schema directly (not wrapped in array)
+            let params_schema = serde_json::to_string(&tool.function.parameters).ok()?;
+            Some(("json_schema".to_string(), params_schema))
+        }
+
+        // Required: Array of tool calls with minItems: 1
+        ToolChoice::Value(ToolChoiceValue::Required) => {
+            let schema = build_required_array_schema(tools)?;
+            Some(("json_schema".to_string(), schema))
+        }
+
+        // AllowedTools with required mode: tools are already filtered
+        ToolChoice::AllowedTools { mode, .. } => {
+            if mode == "required" {
+                if tools.is_empty() {
+                    return None;
+                }
+                let schema = build_required_array_schema(tools)?;
+                Some(("json_schema".to_string(), schema))
+            } else {
+                // "auto" mode - no constraint needed
+                None
+            }
+        }
+
+        // "auto" or "none" - no constraint
+        _ => None,
+    }
+}
+
+/// Build JSON schema for required tool calls (array with minItems: 1)
+/// Includes $defs consolidation from all tools (matching Python's behavior)
+pub fn build_required_array_schema(tools: &[Tool]) -> Option<String> {
+    // Build anyOf schemas for each tool
+    let mut any_of_schemas = Vec::new();
+    for tool in tools {
+        let tool_schema = json!({
+            "properties": {
+                "name": {
+                    "type": "string",
+                    "enum": [tool.function.name]
+                },
+                "parameters": tool.function.parameters
+            },
+            "required": ["name", "parameters"]
+        });
+        any_of_schemas.push(tool_schema);
+    }
+
+    // Consolidate $defs from all tools (matching Python's _get_tool_schema_defs)
+    let mut all_defs: HashMap<String, Value> = HashMap::new();
+    for tool in tools {
+        if let Value::Object(params) = &tool.function.parameters {
+            if let Some(Value::Object(defs)) = params.get("$defs") {
+                for (def_name, def_schema) in defs {
+                    if let Some(existing) = all_defs.get(def_name) {
+                        // Check for conflicts
+                        if existing != def_schema {
+                            error!(
+                                "Tool definition '{}' has multiple schemas, which is not supported",
+                                def_name
+                            );
+                            return None;
+                        }
+                    } else {
+                        all_defs.insert(def_name.clone(), def_schema.clone());
+                    }
+                }
+            }
+        }
+    }
+
+    // Build the full array schema
+    let mut array_schema = json!({
+        "type": "array",
+        "minItems": 1,
+        "items": {
+            "type": "object",
+            "anyOf": any_of_schemas
+        }
+    });
+
+    // Add $defs if any were found (matching Python's behavior)
+    if !all_defs.is_empty() {
+        if let Value::Object(ref mut schema_obj) = array_schema {
+            let defs_value = Value::Object(all_defs.into_iter().collect::<Map<String, Value>>());
+            schema_obj.insert("$defs".to_string(), defs_value);
+        }
+    }
+
+    serde_json::to_string(&array_schema).ok()
+}
+
+/// Filter tools based on tool_choice (shared by both routers)
+/// Returns a reference to the original body if no filtering needed,
+/// otherwise returns a cloned and filtered body
+pub fn filter_tools_for_request(
+    body: &ChatCompletionRequest,
+) -> std::borrow::Cow<'_, ChatCompletionRequest> {
+    match &body.tool_choice {
+        Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => {
+            let mut filtered_body = body.clone();
+            let all_tools = filtered_body.tools.as_ref().unwrap();
+            let allowed_names: std::collections::HashSet<&str> =
+                allowed.iter().map(|t| t.name.as_str()).collect();
+            let filtered_tools: Vec<Tool> = all_tools
+                .iter()
+                .filter(|t| allowed_names.contains(t.function.name.as_str()))
+                .cloned()
+                .collect();
+            filtered_body.tools = Some(filtered_tools);
+            std::borrow::Cow::Owned(filtered_body)
+        }
+        Some(ToolChoice::Function { function, .. }) if body.tools.is_some() => {
+            let mut filtered_body = body.clone();
+            let all_tools = filtered_body.tools.as_ref().unwrap();
+            let filtered_tools: Vec<Tool> = all_tools
+                .iter()
+                .filter(|t| t.function.name == function.name)
+                .cloned()
+                .collect();
+            filtered_body.tools = Some(filtered_tools);
+            std::borrow::Cow::Owned(filtered_body)
+        }
+        _ => std::borrow::Cow::Borrowed(body), // No filtering needed, use original
+    }
+}
+
+/// Process chat messages and apply template (shared by both routers)
+/// Requires HuggingFace tokenizer with chat template support
+pub fn process_chat_messages(
+    request: &ChatCompletionRequest,
+    tokenizer: &dyn Tokenizer,
+) -> Result<ProcessedMessages, String> {
+    // Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
+    let formatted_text = if let Some(hf_tokenizer) =
+        tokenizer.as_any().downcast_ref::<HuggingFaceTokenizer>()
+    {
+        // Get content format and transform messages accordingly
+        let content_format = hf_tokenizer.chat_template_content_format();
+        let mut transformed_messages = process_content_format(&request.messages, content_format)?;
+
+        // Process tool call arguments in assistant messages
+        process_tool_call_arguments(&mut transformed_messages)?;
+
+        // Convert tools to JSON values for template processing
+        let tools_json: Option<Vec<Value>> = request
+            .tools
+            .as_ref()
+            .map(|tools| {
+                tools
+                    .iter()
+                    .map(serde_json::to_value)
+                    .collect::<Result<Vec<_>, _>>()
+            })
+            .transpose()
+            .map_err(|e| format!("Failed to serialize tools: {}", e))?;
+
+        // Build template kwargs, merging reasoning_effort if present
+        let mut combined_template_kwargs = HashMap::new();
+
+        // Add reasoning_effort if present (like Python does)
+        if let Some(reasoning_effort) = &request.reasoning_effort {
+            combined_template_kwargs.insert(
+                "reasoning_effort".to_string(),
+                Value::String(reasoning_effort.clone()),
+            );
+        }
+
+        // Add any additional template kwargs from request
+        if let Some(template_kwargs) = &request.chat_template_kwargs {
+            for (key, value) in template_kwargs {
+                combined_template_kwargs.insert(key.clone(), value.clone());
+            }
+        }
+
+        let final_template_kwargs = if combined_template_kwargs.is_empty() {
+            None
+        } else {
+            Some(&combined_template_kwargs)
+        };
+
+        let params = ChatTemplateParams {
+            add_generation_prompt: true,
+            continue_final_message: request.continue_final_message,
+            tools: tools_json.as_deref(),
+            template_kwargs: final_template_kwargs,
+            ..Default::default()
+        };
+
+        // Handle assistant prefix for continue_final_message
+        let assistant_prefix = if request.continue_final_message
+            && !transformed_messages.is_empty()
+            && transformed_messages
+                .last()
+                .and_then(|msg| msg.get("role"))
+                .and_then(|v| v.as_str())
+                == Some("assistant")
+        {
+            // Pop the last message to handle it separately
+            let last_msg = transformed_messages.pop().unwrap();
+            last_msg
+                .get("content")
+                .and_then(|v| v.as_str())
+                .map(|s| s.to_string())
+        } else {
+            None
+        };
+
+        // Apply chat template with the (now possibly shorter) list of messages
+        let rendered = hf_tokenizer
+            .apply_chat_template(&transformed_messages, params)
+            .map_err(|e| format!("Failed to apply chat template: {}", e))?;
+
+        // Append assistant prefix if we have one
+        if let Some(prefix) = assistant_prefix {
+            format!("{}{}", rendered, prefix)
+        } else {
+            rendered
+        }
+    } else {
+        return Err(
+            "gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
+        );
+    };
+
+    // Placeholder for multimodal inputs
+    let multimodal_inputs = None;
+
+    Ok(ProcessedMessages {
+        text: formatted_text,
+        multimodal_inputs,
+        stop_sequences: request.stop.clone(),
+    })
+}
+
+/// Error response helpers (shared between regular and PD routers)
+pub fn internal_error_static(msg: &'static str) -> Response {
+    error!("{}", msg);
+    (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
+}
+
+pub fn internal_error_message(message: String) -> Response {
+    error!("{}", message);
+    (StatusCode::INTERNAL_SERVER_ERROR, message).into_response()
+}
+
+/// Create a StopSequenceDecoder from stop parameters
+pub fn create_stop_decoder(
+    tokenizer: &Arc<dyn Tokenizer>,
+    stop: Option<&StringOrArray>,
+    stop_token_ids: Option<&Vec<u32>>,
+    skip_special_tokens: bool,
+    no_stop_trim: bool,
+) -> StopSequenceDecoder {
+    use crate::tokenizer::stop::StopSequenceDecoderBuilder;
+
+    // Extract stop sequences
+    let stop_sequences: Vec<String> = match stop {
+        Some(StringOrArray::String(s)) => vec![s.clone()],
+        Some(StringOrArray::Array(arr)) => arr.clone(),
+        None => vec![],
+    };
+
+    // Build stop sequence decoder
+    let mut builder =
+        StopSequenceDecoderBuilder::new(tokenizer.clone()).skip_special_tokens(skip_special_tokens);
+
+    // Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
+    for seq in stop_sequences {
+        builder = if no_stop_trim {
+            builder.visible_stop_sequence(seq)
+        } else {
+            builder.stop_sequence(seq)
+        };
+    }
+
+    // Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
+    if let Some(token_ids) = stop_token_ids {
+        for &token_id in token_ids {
+            builder = if no_stop_trim {
+                builder.visible_stop_token(token_id)
+            } else {
+                builder.stop_token(token_id)
+            };
+        }
+    }
+
+    builder.build()
+}
+
+/// Parse tool calls from JSON schema constrained response
+pub fn parse_json_schema_response(
+    processed_text: &str,
+    tool_choice: &Option<ToolChoice>,
+) -> (Option<Vec<ToolCall>>, String) {
+    match tool_choice {
+        Some(ToolChoice::Function { function, .. }) => {
+            // Specific function: Parse parameters directly
+            match serde_json::from_str::<Value>(processed_text) {
+                Ok(params) => {
+                    let tool_call = ToolCall {
+                        id: format!("call_{}", uuid::Uuid::new_v4()),
+                        tool_type: "function".to_string(),
+                        function: FunctionCallResponse {
+                            name: function.name.clone(),
+                            arguments: Some(
+                                serde_json::to_string(&params).unwrap_or_else(|_| "{}".to_string()),
+                            ),
+                        },
+                    };
+                    (Some(vec![tool_call]), String::new())
+                }
+                Err(e) => {
+                    error!("Failed to parse specific function parameters: {}", e);
+                    (None, processed_text.to_string())
+                }
+            }
+        }
+        Some(ToolChoice::Value(ToolChoiceValue::Required))
+        | Some(ToolChoice::AllowedTools { .. }) => {
+            // Required mode: Parse array of tool calls
+            match serde_json::from_str::<Vec<Value>>(processed_text) {
+                Ok(parsed_array) => {
+                    let spec_tool_calls: Vec<ToolCall> = parsed_array
+                        .into_iter()
+                        .enumerate()
+                        .filter_map(|(i, item)| {
+                            let obj = item.as_object()?;
+                            let name = obj.get("name")?.as_str()?.to_string();
+                            let parameters = obj.get("parameters")?;
+
+                            Some(ToolCall {
+                                id: format!("call_{}_{}", i, uuid::Uuid::new_v4()),
+                                tool_type: "function".to_string(),
+                                function: FunctionCallResponse {
+                                    name,
+                                    arguments: Some(
+                                        serde_json::to_string(parameters)
+                                            .unwrap_or_else(|_| "{}".to_string()),
+                                    ),
+                                },
+                            })
+                        })
+                        .collect();
+                    (Some(spec_tool_calls), String::new())
+                }
+                Err(e) => {
+                    error!("Failed to parse required tool call array: {}", e);
+                    (None, processed_text.to_string())
+                }
+            }
+        }
+        _ => (None, processed_text.to_string()),
+    }
+}
+
+/// Collect responses from a gRPC stream
+///
+/// This helper processes a gRPC GenerateResponse stream and collects all Complete responses.
+/// Used by both regular and PD routers for non-streaming requests.
+///
+/// # Arguments
+/// * `stream` - The gRPC response stream to consume
+/// * `worker_name` - Name for logging (e.g., "Prefill", "Decode", "Worker")
+///
+/// # Returns
+/// * `Ok(Vec<GenerateComplete>)` - All complete responses collected from the stream
+/// * `Err(Response)` - Error response if the stream fails or returns an error
+pub async fn collect_stream_responses(
+    mut stream: Streaming<proto::GenerateResponse>,
+    worker_name: &str,
+) -> Result<Vec<proto::GenerateComplete>, Response> {
+    use proto::generate_response::Response::*;
+
+    let mut all_responses = Vec::new();
+
+    while let Some(response) = stream.next().await {
+        match response {
+            Ok(gen_response) => {
+                match gen_response.response {
+                    Some(Complete(complete)) => {
+                        debug!(
+                            "{} completed: prompt_tokens={}, completion_tokens={}, finish_reason={}",
+                            worker_name, complete.prompt_tokens, complete.completion_tokens, complete.finish_reason
+                        );
+                        all_responses.push(complete);
+                    }
+                    Some(Error(err)) => {
+                        error!("{} error: {}", worker_name, err.message);
+                        return Err(internal_error_message(format!(
+                            "{} generation failed: {}",
+                            worker_name, err.message
+                        )));
+                    }
+                    Some(Chunk(chunk)) => {
+                        debug!("{} chunk: {} tokens", worker_name, chunk.token_ids.len());
+                    }
+                    None => {
+                        debug!("{}: empty response", worker_name);
+                    }
+                }
+            }
+            Err(e) => {
+                error!("{} stream error: {:?}", worker_name, e);
+                return Err(internal_error_message(format!(
+                    "{} stream failed: {}",
+                    worker_name, e
+                )));
+            }
+        }
+    }
+
+    debug!("{} stream closed", worker_name);
+    Ok(all_responses)
+}
+
+/// Count the number of tool calls in the request message history
+/// This is used for KimiK2 format which needs globally unique indices
+pub fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
+    request
+        .messages
+        .iter()
+        .filter_map(|msg| {
+            if let ChatMessage::Assistant { tool_calls, .. } = msg {
+                tool_calls.as_ref().map(|calls| calls.len())
+            } else {
+                None
+            }
+        })
+        .sum()
+}
+
+/// Generate a tool call ID based on model format
+///
+/// # Arguments
+/// * `model` - Model name to determine ID format
+/// * `tool_name` - Name of the tool being called
+/// * `tool_index` - Index of this tool call within the current message
+/// * `history_count` - Number of tool calls in previous messages
+///
+/// # Returns
+/// A unique ID string. KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}`
+pub fn generate_tool_call_id(
+    model: &str,
+    tool_name: &str,
+    tool_index: usize,
+    history_count: usize,
+) -> String {
+    if model.to_lowercase().contains("kimi") {
+        // KimiK2 format: functions.{name}:{global_index}
+        format!("functions.{}:{}", tool_name, history_count + tool_index)
+    } else {
+        // Standard OpenAI format: call_{24-char-uuid}
+        format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::protocols::spec::{ChatMessage, ContentPart, ImageUrl, UserMessageContent};
+    use crate::tokenizer::chat_template::ChatTemplateContentFormat;
+    use serde_json::json;
+
+    #[test]
+    fn test_transform_messages_string_format() {
+        let messages = vec![ChatMessage::User {
+            role: "user".to_string(),
+            content: UserMessageContent::Parts(vec![
+                ContentPart::Text {
+                    text: "Hello".to_string(),
+                },
+                ContentPart::ImageUrl {
+                    image_url: ImageUrl {
+                        url: "https://example.com/image.jpg".to_string(),
+                        detail: None,
+                    },
+                },
+                ContentPart::Text {
+                    text: "World".to_string(),
+                },
+            ]),
+            name: None,
+        }];
+
+        let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
+
+        assert_eq!(result.len(), 1);
+        let transformed_message = &result[0];
+
+        // Should flatten multimodal content to text only
+        assert_eq!(
+            transformed_message["content"].as_str().unwrap(),
+            "Hello World"
+        );
+        assert_eq!(transformed_message["role"].as_str().unwrap(), "user");
+    }
+
+    #[test]
+    fn test_transform_messages_openai_format() {
+        let messages = vec![ChatMessage::User {
+            role: "user".to_string(),
+            content: UserMessageContent::Parts(vec![
+                ContentPart::Text {
+                    text: "Describe this image:".to_string(),
+                },
+                ContentPart::ImageUrl {
+                    image_url: ImageUrl {
+                        url: "https://example.com/image.jpg".to_string(),
+                        detail: Some("high".to_string()),
+                    },
+                },
+            ]),
+            name: None,
+        }];
+
+        let result = process_content_format(&messages, ChatTemplateContentFormat::OpenAI).unwrap();
+
+        assert_eq!(result.len(), 1);
+        let transformed_message = &result[0];
+
+        // Should replace media URLs with simple type placeholders
+        let content_array = transformed_message["content"].as_array().unwrap();
+        assert_eq!(content_array.len(), 2);
+
+        // Text part should remain unchanged
+        assert_eq!(content_array[0]["type"], "text");
+        assert_eq!(content_array[0]["text"], "Describe this image:");
+
+        // Image part should be replaced with simple type placeholder
+        assert_eq!(content_array[1], json!({"type": "image"}));
+    }
+
+    #[test]
+    fn test_transform_messages_simple_string_content() {
+        let messages = vec![ChatMessage::User {
+            role: "user".to_string(),
+            content: UserMessageContent::Text("Simple text message".to_string()),
+            name: None,
+        }];
+
+        let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
+
+        assert_eq!(result.len(), 1);
+        let transformed_message = &result[0];
+
+        // Simple string content should remain unchanged
+        assert_eq!(
+            transformed_message["content"].as_str().unwrap(),
+            "Simple text message"
+        );
+    }
+
+    #[test]
+    fn test_transform_messages_multiple_messages() {
+        let messages = vec![
+            ChatMessage::System {
+                role: "system".to_string(),
+                content: "System prompt".to_string(),
+                name: None,
+            },
+            ChatMessage::User {
+                role: "user".to_string(),
+                content: UserMessageContent::Parts(vec![
+                    ContentPart::Text {
+                        text: "User message".to_string(),
+                    },
+                    ContentPart::ImageUrl {
+                        image_url: ImageUrl {
+                            url: "https://example.com/image.jpg".to_string(),
+                            detail: None,
+                        },
+                    },
+                ]),
+                name: None,
+            },
+        ];
+
+        let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
+
+        assert_eq!(result.len(), 2);
+
+        // System message should remain unchanged
+        assert_eq!(result[0]["role"].as_str().unwrap(), "system");
+        assert_eq!(result[0]["content"].as_str().unwrap(), "System prompt");
+
+        // User message should be flattened to text only
+        assert_eq!(result[1]["role"].as_str().unwrap(), "user");
+        assert_eq!(result[1]["content"].as_str().unwrap(), "User message");
+    }
+
+    #[test]
+    fn test_transform_messages_empty_text_parts() {
+        let messages = vec![ChatMessage::User {
+            role: "user".to_string(),
+            content: UserMessageContent::Parts(vec![ContentPart::ImageUrl {
+                image_url: ImageUrl {
+                    url: "https://example.com/image.jpg".to_string(),
+                    detail: None,
+                },
+            }]),
+            name: None,
+        }];
+
+        let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
+
+        assert_eq!(result.len(), 1);
+        let transformed_message = &result[0];
+
+        // Should keep original multimodal content when no text parts exist
+        assert!(transformed_message["content"].is_array());
+    }
+
+    #[test]
+    fn test_transform_messages_mixed_content_types() {
+        let messages = vec![
+            ChatMessage::User {
+                role: "user".to_string(),
+                content: UserMessageContent::Text("Plain text".to_string()),
+                name: None,
+            },
+            ChatMessage::User {
+                role: "user".to_string(),
+                content: UserMessageContent::Parts(vec![
+                    ContentPart::Text {
+                        text: "With image".to_string(),
+                    },
+                    ContentPart::ImageUrl {
+                        image_url: ImageUrl {
+                            url: "https://example.com/image.jpg".to_string(),
+                            detail: Some("low".to_string()),
+                        },
+                    },
+                ]),
+                name: None,
+            },
+        ];
+
+        let result_string =
+            process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
+
+        assert_eq!(result_string.len(), 2);
+        assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
+        assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
+
+        let result_openai =
+            process_content_format(&messages, ChatTemplateContentFormat::OpenAI).unwrap();
+
+        assert_eq!(result_openai.len(), 2);
+        assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text");
+
+        let content_array = result_openai[1]["content"].as_array().unwrap();
+        assert_eq!(content_array.len(), 2);
+        assert_eq!(content_array[0]["type"], "text");
+        assert_eq!(content_array[1], json!({"type": "image"}));
+    }
+}