[router] fix pd model completion request (#8303)
This commit is contained in:
@@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest {
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
other: serde_json::Map::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -91,6 +91,10 @@ pub struct CompletionRequest {
|
||||
/// If specified, our system will make a best effort to sample deterministically
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<i64>,
|
||||
|
||||
/// Additional fields including bootstrap info for PD routing
|
||||
#[serde(flatten)]
|
||||
pub other: serde_json::Map<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl GenerationRequest for CompletionRequest {
|
||||
|
||||
@@ -420,6 +420,77 @@ impl PDRouter {
|
||||
.await
|
||||
}
|
||||
|
||||
// Route a completion request while preserving OpenAI format
|
||||
pub async fn route_completion(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
mut typed_req: CompletionRequest,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
let start = Instant::now();
|
||||
|
||||
// Get stream flag and return_logprob flag before moving the request
|
||||
let is_stream = typed_req.stream;
|
||||
let return_logprob = typed_req.logprobs.is_some();
|
||||
|
||||
// Extract text for cache-aware routing from the typed request
|
||||
let request_text = match &typed_req.prompt {
|
||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
||||
crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
|
||||
};
|
||||
|
||||
// Select servers
|
||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
error!("Failed to select PD pair: {}", e);
|
||||
RouterMetrics::record_pd_error("server_selection");
|
||||
return HttpResponse::ServiceUnavailable()
|
||||
.body(format!("No available servers: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
// Log routing decision
|
||||
info!(
|
||||
"PD routing: {} -> prefill={}, decode={}",
|
||||
route,
|
||||
prefill.url(),
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Add bootstrap info using the trait method
|
||||
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||
error!("Failed to add bootstrap info: {}", e);
|
||||
RouterMetrics::record_pd_error("bootstrap_injection");
|
||||
return HttpResponse::InternalServerError()
|
||||
.body(format!("Bootstrap injection failed: {}", e));
|
||||
}
|
||||
|
||||
// Convert to JSON after bootstrap injection
|
||||
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
|
||||
Ok(json) => json,
|
||||
Err(e) => {
|
||||
error!("Failed to serialize request: {}", e);
|
||||
return HttpResponse::InternalServerError().body("Failed to serialize request");
|
||||
}
|
||||
};
|
||||
|
||||
// Execute dual dispatch
|
||||
self.execute_dual_dispatch(
|
||||
client,
|
||||
req,
|
||||
json_with_bootstrap,
|
||||
route,
|
||||
prefill.as_ref(),
|
||||
decode.as_ref(),
|
||||
is_stream,
|
||||
return_logprob,
|
||||
start,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Execute the dual dispatch to prefill and decode servers
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn execute_dual_dispatch(
|
||||
@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse {
|
||||
match serde_json::from_value::<CompletionRequest>(body.clone()) {
|
||||
match serde_json::from_value::<CompletionRequest>(body) {
|
||||
Ok(openai_req) => {
|
||||
// Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput)
|
||||
let pd_req = openai_req.to_pd_request();
|
||||
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
|
||||
}
|
||||
Err(_) => {
|
||||
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
|
||||
match serde_json::from_value::<GenerateReqInput>(body) {
|
||||
Ok(pd_req) => {
|
||||
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
|
||||
}
|
||||
Err(e) => {
|
||||
HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
|
||||
}
|
||||
}
|
||||
// Use the new method that preserves OpenAI format
|
||||
PDRouter::route_completion(self, client, req, openai_req, "/v1/completions").await
|
||||
}
|
||||
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// Essential PDLB types extracted for PD routing
|
||||
|
||||
use crate::core::{Worker, WorkerType};
|
||||
use crate::openai_api_types::{CompletionRequest, StringOrArray};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
@@ -233,3 +234,235 @@ impl Bootstrap for ChatReqInput {
|
||||
self.bootstrap_room = Some(bootstrap_room);
|
||||
}
|
||||
}
|
||||
|
||||
// Bootstrap implementation for CompletionRequest to preserve OpenAI format
|
||||
impl Bootstrap for CompletionRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||
if let StringOrArray::Array(prompts) = &self.prompt {
|
||||
if prompts.is_empty() {
|
||||
return Err("Batch prompt array is empty".to_string());
|
||||
}
|
||||
return Ok(Some(prompts.len()));
|
||||
}
|
||||
|
||||
// Single string prompt
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
) {
|
||||
// Insert bootstrap_host - it serializes correctly whether Single or Batch
|
||||
if let Ok(host_value) = serde_json::to_value(&bootstrap_host) {
|
||||
self.other.insert("bootstrap_host".to_string(), host_value);
|
||||
}
|
||||
|
||||
// Insert bootstrap_port - it serializes correctly whether Single or Batch
|
||||
if let Ok(port_value) = serde_json::to_value(&bootstrap_port) {
|
||||
self.other.insert("bootstrap_port".to_string(), port_value);
|
||||
}
|
||||
|
||||
// Insert bootstrap_room - it serializes correctly whether Single or Batch
|
||||
if let Ok(room_value) = serde_json::to_value(&bootstrap_room) {
|
||||
self.other.insert("bootstrap_room".to_string(), room_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod bootstrap_tests {
|
||||
use super::*;
|
||||
use crate::openai_api_types::StringOrArray;
|
||||
|
||||
#[test]
|
||||
fn test_completion_batch_size_with_array_prompt() {
|
||||
let req = CompletionRequest {
|
||||
model: "test".to_string(),
|
||||
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
||||
n: None,
|
||||
other: serde_json::Map::new(),
|
||||
suffix: None,
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
logprobs: None,
|
||||
echo: false,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
best_of: None,
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
};
|
||||
|
||||
// Should return batch size for array prompt
|
||||
assert_eq!(req.get_batch_size().unwrap(), Some(2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completion_batch_size_with_single_prompt() {
|
||||
let req = CompletionRequest {
|
||||
model: "test".to_string(),
|
||||
prompt: StringOrArray::String("single prompt".to_string()),
|
||||
n: None,
|
||||
other: serde_json::Map::new(),
|
||||
suffix: None,
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
logprobs: None,
|
||||
echo: false,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
best_of: None,
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
};
|
||||
|
||||
// Should return None for single prompt
|
||||
assert_eq!(req.get_batch_size().unwrap(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completion_batch_size_with_n_parameter() {
|
||||
let req = CompletionRequest {
|
||||
model: "test".to_string(),
|
||||
prompt: StringOrArray::String("single prompt".to_string()),
|
||||
n: Some(3),
|
||||
other: serde_json::Map::new(),
|
||||
suffix: None,
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
logprobs: None,
|
||||
echo: false,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
best_of: None,
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
};
|
||||
|
||||
// Should return None for single string prompt, even with n > 1
|
||||
// SGLang handles n parameter differently than batch requests
|
||||
assert_eq!(req.get_batch_size().unwrap(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completion_bootstrap_single_values() {
|
||||
let mut req = CompletionRequest {
|
||||
model: "test".to_string(),
|
||||
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
||||
n: None,
|
||||
other: serde_json::Map::new(),
|
||||
suffix: None,
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
logprobs: None,
|
||||
echo: false,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
best_of: None,
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
};
|
||||
|
||||
// Set bootstrap info - should always use single values
|
||||
req.set_bootstrap_info(
|
||||
BootstrapHost::Single("test-server".to_string()),
|
||||
BootstrapPort::Single(Some(5678)),
|
||||
BootstrapRoom::Single(12345),
|
||||
);
|
||||
|
||||
// Verify single values were created
|
||||
assert!(req.other.get("bootstrap_host").unwrap().is_string());
|
||||
assert!(req.other.get("bootstrap_port").unwrap().is_number());
|
||||
assert!(req.other.get("bootstrap_room").unwrap().is_number());
|
||||
|
||||
assert_eq!(
|
||||
req.other.get("bootstrap_host").unwrap().as_str().unwrap(),
|
||||
"test-server"
|
||||
);
|
||||
assert_eq!(
|
||||
req.other.get("bootstrap_port").unwrap().as_u64().unwrap(),
|
||||
5678
|
||||
);
|
||||
assert_eq!(
|
||||
req.other.get("bootstrap_room").unwrap().as_u64().unwrap(),
|
||||
12345
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completion_bootstrap_array_values() {
|
||||
let mut req = CompletionRequest {
|
||||
model: "test".to_string(),
|
||||
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
||||
n: None,
|
||||
other: serde_json::Map::new(),
|
||||
suffix: None,
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
logprobs: None,
|
||||
echo: false,
|
||||
stop: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
best_of: None,
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
};
|
||||
|
||||
// Set bootstrap info with arrays
|
||||
req.set_bootstrap_info(
|
||||
BootstrapHost::Batch(vec!["test-server".to_string(); 2]),
|
||||
BootstrapPort::Batch(vec![Some(5678); 2]),
|
||||
BootstrapRoom::Batch(vec![12345, 67890]),
|
||||
);
|
||||
|
||||
// Verify arrays were created correctly
|
||||
assert!(req.other.get("bootstrap_host").unwrap().is_array());
|
||||
assert!(req.other.get("bootstrap_port").unwrap().is_array());
|
||||
assert!(req.other.get("bootstrap_room").unwrap().is_array());
|
||||
|
||||
let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap();
|
||||
assert_eq!(hosts.len(), 2);
|
||||
assert_eq!(hosts[0].as_str().unwrap(), "test-server");
|
||||
|
||||
let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap();
|
||||
assert_eq!(ports.len(), 2);
|
||||
assert_eq!(ports[0].as_u64().unwrap(), 5678);
|
||||
|
||||
let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap();
|
||||
assert_eq!(rooms.len(), 2);
|
||||
assert_eq!(rooms[0].as_u64().unwrap(), 12345);
|
||||
assert_eq!(rooms[1].as_u64().unwrap(), 67890);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -648,6 +648,7 @@ mod tests {
|
||||
user: None,
|
||||
seed: None,
|
||||
suffix: None,
|
||||
other: serde_json::Map::new(),
|
||||
};
|
||||
|
||||
let pd_req = req.to_pd_request();
|
||||
@@ -687,6 +688,7 @@ mod tests {
|
||||
user: None,
|
||||
seed: None,
|
||||
suffix: None,
|
||||
other: serde_json::Map::new(),
|
||||
};
|
||||
|
||||
let pd_req = req.to_pd_request();
|
||||
@@ -725,6 +727,7 @@ mod tests {
|
||||
user: Some("user123".to_string()),
|
||||
seed: Some(42),
|
||||
suffix: Some("...".to_string()),
|
||||
other: serde_json::Map::new(),
|
||||
};
|
||||
|
||||
let pd_req = req.to_pd_request();
|
||||
@@ -768,6 +771,7 @@ mod tests {
|
||||
user: None,
|
||||
seed: None,
|
||||
suffix: None,
|
||||
other: serde_json::Map::new(),
|
||||
};
|
||||
|
||||
let pd_req = req.to_pd_request();
|
||||
@@ -799,6 +803,7 @@ mod tests {
|
||||
user: None,
|
||||
seed: None,
|
||||
suffix: None,
|
||||
other: serde_json::Map::new(),
|
||||
};
|
||||
|
||||
let pd_req = req.to_pd_request();
|
||||
|
||||
@@ -86,6 +86,7 @@ fn test_benchmark_request_creation() {
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
other: serde_json::Map::new(),
|
||||
};
|
||||
|
||||
// Test serialization works
|
||||
@@ -181,6 +182,7 @@ fn test_benchmark_request_adaptation() {
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
other: serde_json::Map::new(),
|
||||
};
|
||||
|
||||
// Test PD adaptation (should not panic)
|
||||
|
||||
Reference in New Issue
Block a user