[router] fix pd model completion request (#8303)
This commit is contained in:
@@ -420,6 +420,77 @@ impl PDRouter {
|
||||
.await
|
||||
}
|
||||
|
||||
// Route a completion request while preserving OpenAI format
|
||||
pub async fn route_completion(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
mut typed_req: CompletionRequest,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
let start = Instant::now();
|
||||
|
||||
// Get stream flag and return_logprob flag before moving the request
|
||||
let is_stream = typed_req.stream;
|
||||
let return_logprob = typed_req.logprobs.is_some();
|
||||
|
||||
// Extract text for cache-aware routing from the typed request
|
||||
let request_text = match &typed_req.prompt {
|
||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
||||
crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
|
||||
};
|
||||
|
||||
// Select servers
|
||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
error!("Failed to select PD pair: {}", e);
|
||||
RouterMetrics::record_pd_error("server_selection");
|
||||
return HttpResponse::ServiceUnavailable()
|
||||
.body(format!("No available servers: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
// Log routing decision
|
||||
info!(
|
||||
"PD routing: {} -> prefill={}, decode={}",
|
||||
route,
|
||||
prefill.url(),
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Add bootstrap info using the trait method
|
||||
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||
error!("Failed to add bootstrap info: {}", e);
|
||||
RouterMetrics::record_pd_error("bootstrap_injection");
|
||||
return HttpResponse::InternalServerError()
|
||||
.body(format!("Bootstrap injection failed: {}", e));
|
||||
}
|
||||
|
||||
// Convert to JSON after bootstrap injection
|
||||
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
|
||||
Ok(json) => json,
|
||||
Err(e) => {
|
||||
error!("Failed to serialize request: {}", e);
|
||||
return HttpResponse::InternalServerError().body("Failed to serialize request");
|
||||
}
|
||||
};
|
||||
|
||||
// Execute dual dispatch
|
||||
self.execute_dual_dispatch(
|
||||
client,
|
||||
req,
|
||||
json_with_bootstrap,
|
||||
route,
|
||||
prefill.as_ref(),
|
||||
decode.as_ref(),
|
||||
is_stream,
|
||||
return_logprob,
|
||||
start,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Execute the dual dispatch to prefill and decode servers
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn execute_dual_dispatch(
|
||||
@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse {
|
||||
match serde_json::from_value::<CompletionRequest>(body.clone()) {
|
||||
match serde_json::from_value::<CompletionRequest>(body) {
|
||||
Ok(openai_req) => {
|
||||
// Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput)
|
||||
let pd_req = openai_req.to_pd_request();
|
||||
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
|
||||
}
|
||||
Err(_) => {
|
||||
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
|
||||
match serde_json::from_value::<GenerateReqInput>(body) {
|
||||
Ok(pd_req) => {
|
||||
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
|
||||
}
|
||||
Err(e) => {
|
||||
HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
|
||||
}
|
||||
}
|
||||
// Use the new method that preserves OpenAI format
|
||||
PDRouter::route_completion(self, client, req, openai_req, "/v1/completions").await
|
||||
}
|
||||
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user