[router] fix error response in pd_router (#9505)
Co-authored-by: bruce.xu <bruce.xu@gmicloud.ai>
This commit is contained in:
@@ -28,7 +28,7 @@ use axum::{
|
|||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::Value;
|
use serde_json::{json, Value};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
@@ -808,6 +808,57 @@ impl PDRouter {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_decode_error_response(
|
||||||
|
&self,
|
||||||
|
res: reqwest::Response,
|
||||||
|
context: &PDRequestContext,
|
||||||
|
prefill: &dyn Worker,
|
||||||
|
decode: &dyn Worker,
|
||||||
|
) -> Response {
|
||||||
|
let status = res.status();
|
||||||
|
|
||||||
|
if context.is_stream {
|
||||||
|
// Handle streaming error response
|
||||||
|
let response_headers = header_utils::preserve_response_headers(res.headers());
|
||||||
|
let error_payload = match res.bytes().await {
|
||||||
|
Ok(error_body) => {
|
||||||
|
if let Ok(error_json) = serde_json::from_slice::<Value>(&error_body) {
|
||||||
|
json!({ "message": error_json, "status": status.as_u16() })
|
||||||
|
} else {
|
||||||
|
json!({ "message": String::from_utf8_lossy(&error_body).to_string(), "status": status.as_u16() })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
json!({ "message": format!("Decode server error: {}", e), "status": status.as_u16() })
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let sse_data = format!(
|
||||||
|
"data: {{'error': {}}}",
|
||||||
|
serde_json::to_string(&error_payload).unwrap_or_default()
|
||||||
|
);
|
||||||
|
let error_stream = tokio_stream::once(Ok(axum::body::Bytes::from(sse_data)));
|
||||||
|
|
||||||
|
let decode_url = decode.url().to_string();
|
||||||
|
self.create_streaming_response(
|
||||||
|
error_stream,
|
||||||
|
status,
|
||||||
|
None,
|
||||||
|
context.return_logprob,
|
||||||
|
Some(decode_url),
|
||||||
|
Some(response_headers),
|
||||||
|
prefill,
|
||||||
|
decode,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Handle non-streaming error response
|
||||||
|
match res.bytes().await {
|
||||||
|
Ok(error_body) => (status, error_body).into_response(),
|
||||||
|
Err(e) => (status, format!("Decode server error: {}", e)).into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Internal method that performs the actual dual dispatch (without retry logic)
|
// Internal method that performs the actual dual dispatch (without retry logic)
|
||||||
async fn execute_dual_dispatch_internal(
|
async fn execute_dual_dispatch_internal(
|
||||||
&self,
|
&self,
|
||||||
@@ -881,16 +932,9 @@ impl PDRouter {
|
|||||||
status
|
status
|
||||||
);
|
);
|
||||||
|
|
||||||
// Return the error response from decode server
|
return self
|
||||||
match res.bytes().await {
|
.handle_decode_error_response(res, &context, prefill, decode)
|
||||||
Ok(error_body) => {
|
.await;
|
||||||
return (status, error_body).into_response();
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
return (status, format!("Decode server error: {}", e))
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process prefill response for logprobs
|
// Process prefill response for logprobs
|
||||||
@@ -1034,13 +1078,8 @@ impl PDRouter {
|
|||||||
status
|
status
|
||||||
);
|
);
|
||||||
|
|
||||||
// Return the error response from decode server
|
self.handle_decode_error_response(res, &context, prefill, decode)
|
||||||
match res.bytes().await {
|
.await
|
||||||
Ok(error_body) => (status, error_body).into_response(),
|
|
||||||
Err(e) => {
|
|
||||||
(status, format!("Decode server error: {}", e)).into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if context.is_stream {
|
} else if context.is_stream {
|
||||||
// Streaming response without logprobs - direct passthrough
|
// Streaming response without logprobs - direct passthrough
|
||||||
let decode_url = decode.url().to_string();
|
let decode_url = decode.url().to_string();
|
||||||
|
|||||||
Reference in New Issue
Block a user