[router] fix pd model completion request (#8303)

This commit is contained in:
Simo Lin
2025-07-23 23:18:29 -07:00
committed by GitHub
parent 5dd0f870ab
commit f6e07f2796
6 changed files with 320 additions and 15 deletions

View File

@@ -420,6 +420,77 @@ impl PDRouter {
.await
}
// Route a completion request while preserving OpenAI format
pub async fn route_completion(
&self,
client: &reqwest::Client,
req: &HttpRequest,
mut typed_req: CompletionRequest,
route: &str,
) -> HttpResponse {
let start = Instant::now();
// Get stream flag and return_logprob flag before moving the request
let is_stream = typed_req.stream;
let return_logprob = typed_req.logprobs.is_some();
// Extract text for cache-aware routing from the typed request
let request_text = match &typed_req.prompt {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
};
// Select servers
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair: {}", e);
RouterMetrics::record_pd_error("server_selection");
return HttpResponse::ServiceUnavailable()
.body(format!("No available servers: {}", e));
}
};
// Log routing decision
info!(
"PD routing: {} -> prefill={}, decode={}",
route,
prefill.url(),
decode.url()
);
// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info: {}", e);
RouterMetrics::record_pd_error("bootstrap_injection");
return HttpResponse::InternalServerError()
.body(format!("Bootstrap injection failed: {}", e));
}
// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request: {}", e);
return HttpResponse::InternalServerError().body("Failed to serialize request");
}
};
// Execute dual dispatch
self.execute_dual_dispatch(
client,
req,
json_with_bootstrap,
route,
prefill.as_ref(),
decode.as_ref(),
is_stream,
return_logprob,
start,
)
.await
}
// Execute the dual dispatch to prefill and decode servers
#[allow(clippy::too_many_arguments)]
async fn execute_dual_dispatch(
@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
req: &HttpRequest,
body: serde_json::Value,
) -> HttpResponse {
match serde_json::from_value::<CompletionRequest>(body.clone()) {
match serde_json::from_value::<CompletionRequest>(body) {
Ok(openai_req) => {
// Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput)
let pd_req = openai_req.to_pd_request();
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
}
Err(_) => {
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
match serde_json::from_value::<GenerateReqInput>(body) {
Ok(pd_req) => {
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
}
Err(e) => {
HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
}
}
// Use the new method that preserves OpenAI format
PDRouter::route_completion(self, client, req, openai_req, "/v1/completions").await
}
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)),
}
}