[router] fix error response in pd_router (#9505)
Co-authored-by: bruce.xu <bruce.xu@gmicloud.ai>
This commit is contained in:
@@ -28,7 +28,7 @@ use axum::{
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
@@ -808,6 +808,57 @@ impl PDRouter {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn handle_decode_error_response(
|
||||
&self,
|
||||
res: reqwest::Response,
|
||||
context: &PDRequestContext,
|
||||
prefill: &dyn Worker,
|
||||
decode: &dyn Worker,
|
||||
) -> Response {
|
||||
let status = res.status();
|
||||
|
||||
if context.is_stream {
|
||||
// Handle streaming error response
|
||||
let response_headers = header_utils::preserve_response_headers(res.headers());
|
||||
let error_payload = match res.bytes().await {
|
||||
Ok(error_body) => {
|
||||
if let Ok(error_json) = serde_json::from_slice::<Value>(&error_body) {
|
||||
json!({ "message": error_json, "status": status.as_u16() })
|
||||
} else {
|
||||
json!({ "message": String::from_utf8_lossy(&error_body).to_string(), "status": status.as_u16() })
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
json!({ "message": format!("Decode server error: {}", e), "status": status.as_u16() })
|
||||
}
|
||||
};
|
||||
|
||||
let sse_data = format!(
|
||||
"data: {{'error': {}}}",
|
||||
serde_json::to_string(&error_payload).unwrap_or_default()
|
||||
);
|
||||
let error_stream = tokio_stream::once(Ok(axum::body::Bytes::from(sse_data)));
|
||||
|
||||
let decode_url = decode.url().to_string();
|
||||
self.create_streaming_response(
|
||||
error_stream,
|
||||
status,
|
||||
None,
|
||||
context.return_logprob,
|
||||
Some(decode_url),
|
||||
Some(response_headers),
|
||||
prefill,
|
||||
decode,
|
||||
)
|
||||
} else {
|
||||
// Handle non-streaming error response
|
||||
match res.bytes().await {
|
||||
Ok(error_body) => (status, error_body).into_response(),
|
||||
Err(e) => (status, format!("Decode server error: {}", e)).into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Internal method that performs the actual dual dispatch (without retry logic)
|
||||
async fn execute_dual_dispatch_internal(
|
||||
&self,
|
||||
@@ -881,16 +932,9 @@ impl PDRouter {
|
||||
status
|
||||
);
|
||||
|
||||
// Return the error response from decode server
|
||||
match res.bytes().await {
|
||||
Ok(error_body) => {
|
||||
return (status, error_body).into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
return (status, format!("Decode server error: {}", e))
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
return self
|
||||
.handle_decode_error_response(res, &context, prefill, decode)
|
||||
.await;
|
||||
}
|
||||
|
||||
// Process prefill response for logprobs
|
||||
@@ -1034,13 +1078,8 @@ impl PDRouter {
|
||||
status
|
||||
);
|
||||
|
||||
// Return the error response from decode server
|
||||
match res.bytes().await {
|
||||
Ok(error_body) => (status, error_body).into_response(),
|
||||
Err(e) => {
|
||||
(status, format!("Decode server error: {}", e)).into_response()
|
||||
}
|
||||
}
|
||||
self.handle_decode_error_response(res, &context, prefill, decode)
|
||||
.await
|
||||
} else if context.is_stream {
|
||||
// Streaming response without logprobs - direct passthrough
|
||||
let decode_url = decode.url().to_string();
|
||||
|
||||
Reference in New Issue
Block a user