diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index c3ac542ce..1d39d748f 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -194,8 +194,7 @@ impl GrpcPDRouter { let (original_text, token_ids) = match self.resolve_generate_input(body) { Ok(res) => res, Err(msg) => { - error!("Invalid generate request: {}", msg); - return (StatusCode::BAD_REQUEST, msg).into_response(); + return utils::bad_request_error(msg); } }; @@ -208,8 +207,7 @@ impl GrpcPDRouter { { Ok(pair) => pair, Err(e) => { - warn!("Failed to select PD worker pair: {}", e); - return (StatusCode::SERVICE_UNAVAILABLE, e).into_response(); + return utils::service_unavailable_error(e); } }; @@ -244,15 +242,13 @@ impl GrpcPDRouter { ) { Ok(req) => req, Err(e) => { - error!("Failed to build generate request: {}", e); - return (StatusCode::BAD_REQUEST, e).into_response(); + return utils::bad_request_error(e); } }; // Step 5: Inject bootstrap metadata if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) { - error!("Failed to inject bootstrap metadata: {}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, e).into_response(); + return utils::internal_error_message(e); } // Step 6: Get weight version for response metadata @@ -334,8 +330,7 @@ impl GrpcPDRouter { let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { Ok(msgs) => msgs, Err(e) => { - error!("Failed to process chat messages: {}", e); - return (StatusCode::BAD_REQUEST, e.to_string()).into_response(); + return utils::bad_request_error(e.to_string()); } }; @@ -343,12 +338,7 @@ impl GrpcPDRouter { let encoding = match self.tokenizer.encode(&processed_messages.text) { Ok(encoding) => encoding, Err(e) => { - error!("Tokenization failed: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Tokenization failed: {}", e), - ) - .into_response(); + return utils::internal_error_message(format!("Tokenization failed: {}", e)); } }; @@ -368,8 +358,7 @@ impl GrpcPDRouter { { Ok(pair) => pair, Err(e) => { - warn!("Failed to select PD worker pair: {}", e); - return (StatusCode::SERVICE_UNAVAILABLE, e).into_response(); + return utils::service_unavailable_error(e); } }; @@ -402,19 +391,13 @@ impl GrpcPDRouter { ) { Ok(request) => request, Err(e) => { - error!("Failed to build gRPC request: {}", e); - return ( - StatusCode::BAD_REQUEST, - format!("Invalid request parameters: {}", e), - ) - .into_response(); + return utils::bad_request_error(format!("Invalid request parameters: {}", e)); } }; // Step 8: Inject bootstrap metadata into the request if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) { - error!("Failed to inject bootstrap metadata: {}", e); - return (StatusCode::INTERNAL_SERVER_ERROR, e).into_response(); + return utils::internal_error_message(e); } // Step 9: Handle streaming vs non-streaming @@ -486,12 +469,10 @@ impl GrpcPDRouter { let prefill_stream = match prefill_result { Ok(s) => s, Err(e) => { - error!("Failed to start prefill generation: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Prefill worker failed to start: {}", e), - ) - .into_response(); + return utils::internal_error_message(format!( + "Prefill worker failed to start: {}", + e + )); } }; @@ -499,12 +480,10 @@ impl GrpcPDRouter { let decode_stream = match decode_result { Ok(s) => s, Err(e) => { - error!("Failed to start decode generation: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Decode worker failed to start: {}", e), - ) - .into_response(); + return utils::internal_error_message(format!( + "Decode worker failed to start: {}", + e + )); } }; @@ -592,12 +571,10 @@ impl GrpcPDRouter { let prefill_stream = match prefill_result { Ok(s) => s, Err(e) => { - error!("Failed to start prefill generation: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Prefill worker failed to start: {}", e), - ) - .into_response(); + return utils::internal_error_message(format!( + "Prefill worker failed to start: {}", + e + )); } }; @@ -605,12 +582,10 @@ impl GrpcPDRouter { let decode_stream = match decode_result { Ok(s) => s, Err(e) => { - error!("Failed to start decode generation: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Decode worker failed to start: {}", e), - ) - .into_response(); + return utils::internal_error_message(format!( + "Decode worker failed to start: {}", + e + )); } }; diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 29f318c80..8dc2316c9 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -113,8 +113,7 @@ impl GrpcRouter { let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) { Ok(msgs) => msgs, Err(e) => { - error!("Failed to process chat messages: {}", e); - return (StatusCode::BAD_REQUEST, e.to_string()).into_response(); + return utils::bad_request_error(e.to_string()); } }; @@ -122,12 +121,7 @@ impl GrpcRouter { let encoding = match self.tokenizer.encode(&processed_messages.text) { Ok(encoding) => encoding, Err(e) => { - error!("Tokenization failed: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Tokenization failed: {}", e), - ) - .into_response(); + return utils::internal_error_message(format!("Tokenization failed: {}", e)); } }; @@ -145,8 +139,10 @@ impl GrpcRouter { { Some(w) => w, None => { - warn!("No available workers for model: {:?}", model_id); - return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); + return utils::service_unavailable_error(format!( + "No available workers for model: {:?}", + model_id + )); } }; @@ -170,12 +166,7 @@ impl GrpcRouter { ) { Ok(request) => request, Err(e) => { - error!("Failed to build gRPC request: {}", e); - return ( - StatusCode::BAD_REQUEST, - format!("Invalid request parameters: {}", e), - ) - .into_response(); + return utils::bad_request_error(format!("Invalid request parameters: {}", e)); } }; @@ -200,8 +191,7 @@ impl GrpcRouter { let (original_text, token_ids) = match self.resolve_generate_input(body) { Ok(res) => res, Err(msg) => { - error!("Invalid generate request: {}", msg); - return (StatusCode::BAD_REQUEST, msg).into_response(); + return utils::bad_request_error(msg); } }; @@ -211,8 +201,10 @@ impl GrpcRouter { let worker = match self.select_worker_for_request(model_id, original_text.as_deref()) { Some(w) => w, None => { - warn!("No available workers for model: {:?}", model_id); - return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); + return utils::service_unavailable_error(format!( + "No available workers for model: {:?}", + model_id + )); } }; @@ -238,8 +230,7 @@ impl GrpcRouter { ) { Ok(req) => req, Err(e) => { - error!("Failed to build generate request: {}", e); - return (StatusCode::BAD_REQUEST, e).into_response(); + return utils::bad_request_error(e); } }; @@ -405,16 +396,6 @@ impl GrpcRouter { Ok((text.to_string(), encoding.token_ids().to_vec())) } - fn internal_error_static(msg: &'static str) -> Response { - error!("{}", msg); - (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response() - } - - fn internal_error_message(message: String) -> Response { - error!("{}", message); - (StatusCode::INTERNAL_SERVER_ERROR, message).into_response() - } - /// Count the number of tool calls in the request message history /// This is used for KimiK2 format which needs globally unique indices fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize { @@ -740,12 +721,7 @@ impl GrpcRouter { let mut grpc_stream = match client.generate(request).await { Ok(stream) => stream, Err(e) => { - error!("Failed to start generation: {}", e); - return ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Generation failed: {}", e), - ) - .into_response(); + return utils::internal_error_message(format!("Generation failed: {}", e)); } }; @@ -1183,7 +1159,7 @@ impl GrpcRouter { let stream = match client.generate(request).await { Ok(s) => s, Err(e) => { - return Self::internal_error_message(format!("Failed to start generation: {}", e)) + return utils::internal_error_message(format!("Failed to start generation: {}", e)) } }; @@ -1193,7 +1169,7 @@ impl GrpcRouter { }; if all_responses.is_empty() { - return Self::internal_error_static("No responses from server"); + return utils::internal_error_static("No responses from server"); } // Process each response into a ChatChoice @@ -1212,7 +1188,7 @@ impl GrpcRouter { { Ok(choice) => choices.push(choice), Err(e) => { - return Self::internal_error_message(format!( + return utils::internal_error_message(format!( "Failed to process choice {}: {}", index, e )); @@ -1265,7 +1241,7 @@ impl GrpcRouter { let stream = match client.generate(request).await { Ok(stream) => stream, Err(e) => { - return Self::internal_error_message(format!("Failed to start generation: {}", e)) + return utils::internal_error_message(format!("Failed to start generation: {}", e)) } }; @@ -1276,7 +1252,7 @@ impl GrpcRouter { }; if responses.is_empty() { - return Self::internal_error_static("No completion received from scheduler"); + return utils::internal_error_static("No completion received from scheduler"); } // Create stop decoder from sampling params @@ -1298,7 +1274,10 @@ impl GrpcRouter { let outputs = match stop_decoder.process_tokens(&complete.output_ids) { Ok(outputs) => outputs, Err(e) => { - return Self::internal_error_message(format!("Failed to process tokens: {}", e)) + return utils::internal_error_message(format!( + "Failed to process tokens: {}", + e + )) } }; @@ -1377,7 +1356,7 @@ impl GrpcRouter { let stream = match client.generate(request).await { Ok(stream) => stream, Err(e) => { - return Self::internal_error_message(format!("Failed to start generation: {}", e)) + return utils::internal_error_message(format!("Failed to start generation: {}", e)) } }; diff --git a/sgl-router/src/routers/grpc/utils.rs b/sgl-router/src/routers/grpc/utils.rs index f8d8de60e..af8d30544 100644 --- a/sgl-router/src/routers/grpc/utils.rs +++ b/sgl-router/src/routers/grpc/utils.rs @@ -14,13 +14,14 @@ pub use crate::tokenizer::StopSequenceDecoder; use axum::{ http::StatusCode, response::{IntoResponse, Response}, + Json, }; use futures::StreamExt; use serde_json::{json, Map, Value}; use std::collections::HashMap; use std::sync::Arc; use tonic::codec::Streaming; -use tracing::{debug, error}; +use tracing::{debug, error, warn}; use uuid::Uuid; /// Get gRPC client from worker, returning appropriate error response on failure @@ -30,22 +31,8 @@ pub async fn get_grpc_client_from_worker( let client_arc = worker .get_grpc_client() .await - .map_err(|e| { - error!("Failed to get gRPC client from worker: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to get gRPC client: {}", e), - ) - .into_response() - })? - .ok_or_else(|| { - error!("Selected worker is not a gRPC worker"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Selected worker is not configured for gRPC", - ) - .into_response() - })?; + .map_err(|e| internal_error_message(format!("Failed to get gRPC client: {}", e)))? + .ok_or_else(|| internal_error_static("Selected worker is not configured for gRPC"))?; let client = client_arc.lock().await.clone(); Ok(client) @@ -422,12 +409,62 @@ pub fn process_chat_messages( /// Error response helpers (shared between regular and PD routers) pub fn internal_error_static(msg: &'static str) -> Response { error!("{}", msg); - (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response() + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": msg, + "type": "internal_error", + "code": 500 + } + })), + ) + .into_response() } pub fn internal_error_message(message: String) -> Response { error!("{}", message); - (StatusCode::INTERNAL_SERVER_ERROR, message).into_response() + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": message, + "type": "internal_error", + "code": 500 + } + })), + ) + .into_response() +} + +pub fn bad_request_error(message: String) -> Response { + error!("{}", message); + ( + StatusCode::BAD_REQUEST, + Json(json!({ + "error": { + "message": message, + "type": "invalid_request_error", + "code": 400 + } + })), + ) + .into_response() +} + +pub fn service_unavailable_error(message: String) -> Response { + warn!("{}", message); + ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ + "error": { + "message": message, + "type": "service_unavailable", + "code": 503 + } + })), + ) + .into_response() } /// Create a StopSequenceDecoder from stop parameters