From 8b30bec265d64880d270da2017849de2a6093a7f Mon Sep 17 00:00:00 2001
From: Bruce-x-1997
Date: Thu, 28 Aug 2025 10:10:55 +0800
Subject: [PATCH] [router] fix error response in pd_router (#9505)

Co-authored-by: bruce.xu
---
 sgl-router/src/routers/pd_router.rs | 75 ++++++++++++++++++++++-------
 1 file changed, 57 insertions(+), 18 deletions(-)

diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs
index 42fd54598..9562c08e4 100644
--- a/sgl-router/src/routers/pd_router.rs
+++ b/sgl-router/src/routers/pd_router.rs
@@ -28,7 +28,7 @@ use axum::{
 use futures_util::StreamExt;
 use reqwest::Client;
 use serde::Serialize;
-use serde_json::Value;
+use serde_json::{json, Value};
 use std::collections::HashMap;
 use std::sync::{Arc, RwLock};
 use std::time::{Duration, Instant};
@@ -808,6 +808,57 @@ impl PDRouter {
         .await
     }
 
+    /// Builds the client-facing response for an error reply from the decode server.
+    ///
+    /// The upstream `status` is kept in both modes. When `context.is_stream` is set, the
+    /// upstream body (or a `Decode server error: ...` message if it cannot be read) is
+    /// wrapped in one SSE event, `data: {"error": {"message": ..., "status": ...}}`, and
+    /// passed to `create_streaming_response`; otherwise the upstream body (or that same
+    /// fallback message) is returned as a plain response.
+    async fn handle_decode_error_response(
+        &self,
+        res: reqwest::Response,
+        context: &PDRequestContext,
+        prefill: &dyn Worker,
+        decode: &dyn Worker,
+    ) -> Response {
+        let status = res.status();
+
+        if context.is_stream {
+            // Copy the headers first: `res.bytes()` takes the response by value.
+            let response_headers = header_utils::preserve_response_headers(res.headers());
+            let message = match res.bytes().await {
+                // Forward a JSON error body as structured JSON, anything else as lossy UTF-8.
+                Ok(error_body) => serde_json::from_slice::<Value>(&error_body)
+                    .unwrap_or_else(|_| json!(String::from_utf8_lossy(&error_body))),
+                Err(e) => json!(format!("Decode server error: {}", e)),
+            };
+            let error_payload = json!({ "message": message, "status": status.as_u16() });
+
+            // Emit valid JSON (double-quoted key) and end the SSE event with a blank line.
+            let sse_data = format!("data: {}\n\n", json!({ "error": error_payload }));
+            let error_stream = tokio_stream::once(Ok(axum::body::Bytes::from(sse_data)));
+
+            let decode_url = decode.url().to_string();
+            self.create_streaming_response(
+                error_stream,
+                status,
+                None,
+                context.return_logprob,
+                Some(decode_url),
+                Some(response_headers),
+                prefill,
+                decode,
+            )
+        } else {
+            // Non-streaming: pass the decode server's error body through unchanged.
+            match res.bytes().await {
+                Ok(error_body) => (status, error_body).into_response(),
+                Err(e) => (status, format!("Decode server error: {}", e)).into_response(),
+            }
+        }
+    }
+
     // Internal method that performs the actual dual dispatch (without retry logic)
     async fn execute_dual_dispatch_internal(
         &self,
@@ -881,16 +932,9 @@ impl PDRouter {
                             status
                         );
 
-                        // Return the error response from decode server
-                        match res.bytes().await {
-                            Ok(error_body) => {
-                                return (status, error_body).into_response();
-                            }
-                            Err(e) => {
-                                return (status, format!("Decode server error: {}", e))
-                                    .into_response();
-                            }
-                        }
+                        return self
+                            .handle_decode_error_response(res, &context, prefill, decode)
+                            .await;
                     }
 
                     // Process prefill response for logprobs
@@ -1034,13 +1078,8 @@ impl PDRouter {
                             status
                         );
 
-                        // Return the error response from decode server
-                        match res.bytes().await {
-                            Ok(error_body) => (status, error_body).into_response(),
-                            Err(e) => {
-                                (status, format!("Decode server error: {}", e)).into_response()
-                            }
-                        }
+                        self.handle_decode_error_response(res, &context, prefill, decode)
+                            .await
                     } else if context.is_stream {
                         // Streaming response without logprobs - direct passthrough
                         let decode_url = decode.url().to_string();