[router] fix pd model completion request (#8303)
This commit is contained in:
@@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest {
|
|||||||
logit_bias: None,
|
logit_bias: None,
|
||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -91,6 +91,10 @@ pub struct CompletionRequest {
|
|||||||
/// If specified, our system will make a best effort to sample deterministically
|
/// If specified, our system will make a best effort to sample deterministically
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub seed: Option<i64>,
|
pub seed: Option<i64>,
|
||||||
|
|
||||||
|
/// Additional fields including bootstrap info for PD routing
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub other: serde_json::Map<String, serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GenerationRequest for CompletionRequest {
|
impl GenerationRequest for CompletionRequest {
|
||||||
|
|||||||
@@ -420,6 +420,77 @@ impl PDRouter {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Route a completion request while preserving OpenAI format
|
||||||
|
pub async fn route_completion(
|
||||||
|
&self,
|
||||||
|
client: &reqwest::Client,
|
||||||
|
req: &HttpRequest,
|
||||||
|
mut typed_req: CompletionRequest,
|
||||||
|
route: &str,
|
||||||
|
) -> HttpResponse {
|
||||||
|
let start = Instant::now();
|
||||||
|
|
||||||
|
// Get stream flag and return_logprob flag before moving the request
|
||||||
|
let is_stream = typed_req.stream;
|
||||||
|
let return_logprob = typed_req.logprobs.is_some();
|
||||||
|
|
||||||
|
// Extract text for cache-aware routing from the typed request
|
||||||
|
let request_text = match &typed_req.prompt {
|
||||||
|
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
||||||
|
crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Select servers
|
||||||
|
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||||
|
Ok(pair) => pair,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to select PD pair: {}", e);
|
||||||
|
RouterMetrics::record_pd_error("server_selection");
|
||||||
|
return HttpResponse::ServiceUnavailable()
|
||||||
|
.body(format!("No available servers: {}", e));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Log routing decision
|
||||||
|
info!(
|
||||||
|
"PD routing: {} -> prefill={}, decode={}",
|
||||||
|
route,
|
||||||
|
prefill.url(),
|
||||||
|
decode.url()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Add bootstrap info using the trait method
|
||||||
|
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||||
|
error!("Failed to add bootstrap info: {}", e);
|
||||||
|
RouterMetrics::record_pd_error("bootstrap_injection");
|
||||||
|
return HttpResponse::InternalServerError()
|
||||||
|
.body(format!("Bootstrap injection failed: {}", e));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to JSON after bootstrap injection
|
||||||
|
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
|
||||||
|
Ok(json) => json,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to serialize request: {}", e);
|
||||||
|
return HttpResponse::InternalServerError().body("Failed to serialize request");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Execute dual dispatch
|
||||||
|
self.execute_dual_dispatch(
|
||||||
|
client,
|
||||||
|
req,
|
||||||
|
json_with_bootstrap,
|
||||||
|
route,
|
||||||
|
prefill.as_ref(),
|
||||||
|
decode.as_ref(),
|
||||||
|
is_stream,
|
||||||
|
return_logprob,
|
||||||
|
start,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
// Execute the dual dispatch to prefill and decode servers
|
// Execute the dual dispatch to prefill and decode servers
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn execute_dual_dispatch(
|
async fn execute_dual_dispatch(
|
||||||
@@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter {
|
|||||||
req: &HttpRequest,
|
req: &HttpRequest,
|
||||||
body: serde_json::Value,
|
body: serde_json::Value,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
match serde_json::from_value::<CompletionRequest>(body.clone()) {
|
match serde_json::from_value::<CompletionRequest>(body) {
|
||||||
Ok(openai_req) => {
|
Ok(openai_req) => {
|
||||||
// Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput)
|
// Use the new method that preserves OpenAI format
|
||||||
let pd_req = openai_req.to_pd_request();
|
PDRouter::route_completion(self, client, req, openai_req, "/v1/completions").await
|
||||||
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
|
|
||||||
}
|
|
||||||
Err(_) => {
|
|
||||||
// If that fails, try to deserialize directly as PD format (for backwards compatibility)
|
|
||||||
match serde_json::from_value::<GenerateReqInput>(body) {
|
|
||||||
Ok(pd_req) => {
|
|
||||||
PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
HttpResponse::BadRequest().body(format!("Invalid request format: {}", e))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
// Essential PDLB types extracted for PD routing
|
// Essential PDLB types extracted for PD routing
|
||||||
|
|
||||||
use crate::core::{Worker, WorkerType};
|
use crate::core::{Worker, WorkerType};
|
||||||
|
use crate::openai_api_types::{CompletionRequest, StringOrArray};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
@@ -233,3 +234,235 @@ impl Bootstrap for ChatReqInput {
|
|||||||
self.bootstrap_room = Some(bootstrap_room);
|
self.bootstrap_room = Some(bootstrap_room);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bootstrap implementation for CompletionRequest to preserve OpenAI format
|
||||||
|
impl Bootstrap for CompletionRequest {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||||
|
if let StringOrArray::Array(prompts) = &self.prompt {
|
||||||
|
if prompts.is_empty() {
|
||||||
|
return Err("Batch prompt array is empty".to_string());
|
||||||
|
}
|
||||||
|
return Ok(Some(prompts.len()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single string prompt
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_bootstrap_info(
|
||||||
|
&mut self,
|
||||||
|
bootstrap_host: BootstrapHost,
|
||||||
|
bootstrap_port: BootstrapPort,
|
||||||
|
bootstrap_room: BootstrapRoom,
|
||||||
|
) {
|
||||||
|
// Insert bootstrap_host - it serializes correctly whether Single or Batch
|
||||||
|
if let Ok(host_value) = serde_json::to_value(&bootstrap_host) {
|
||||||
|
self.other.insert("bootstrap_host".to_string(), host_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert bootstrap_port - it serializes correctly whether Single or Batch
|
||||||
|
if let Ok(port_value) = serde_json::to_value(&bootstrap_port) {
|
||||||
|
self.other.insert("bootstrap_port".to_string(), port_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert bootstrap_room - it serializes correctly whether Single or Batch
|
||||||
|
if let Ok(room_value) = serde_json::to_value(&bootstrap_room) {
|
||||||
|
self.other.insert("bootstrap_room".to_string(), room_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod bootstrap_tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::openai_api_types::StringOrArray;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_completion_batch_size_with_array_prompt() {
|
||||||
|
let req = CompletionRequest {
|
||||||
|
model: "test".to_string(),
|
||||||
|
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
||||||
|
n: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
logprobs: None,
|
||||||
|
echo: false,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
best_of: None,
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Should return batch size for array prompt
|
||||||
|
assert_eq!(req.get_batch_size().unwrap(), Some(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_completion_batch_size_with_single_prompt() {
|
||||||
|
let req = CompletionRequest {
|
||||||
|
model: "test".to_string(),
|
||||||
|
prompt: StringOrArray::String("single prompt".to_string()),
|
||||||
|
n: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
logprobs: None,
|
||||||
|
echo: false,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
best_of: None,
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Should return None for single prompt
|
||||||
|
assert_eq!(req.get_batch_size().unwrap(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_completion_batch_size_with_n_parameter() {
|
||||||
|
let req = CompletionRequest {
|
||||||
|
model: "test".to_string(),
|
||||||
|
prompt: StringOrArray::String("single prompt".to_string()),
|
||||||
|
n: Some(3),
|
||||||
|
other: serde_json::Map::new(),
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
logprobs: None,
|
||||||
|
echo: false,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
best_of: None,
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Should return None for single string prompt, even with n > 1
|
||||||
|
// SGLang handles n parameter differently than batch requests
|
||||||
|
assert_eq!(req.get_batch_size().unwrap(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_completion_bootstrap_single_values() {
|
||||||
|
let mut req = CompletionRequest {
|
||||||
|
model: "test".to_string(),
|
||||||
|
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
||||||
|
n: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
logprobs: None,
|
||||||
|
echo: false,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
best_of: None,
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set bootstrap info - should always use single values
|
||||||
|
req.set_bootstrap_info(
|
||||||
|
BootstrapHost::Single("test-server".to_string()),
|
||||||
|
BootstrapPort::Single(Some(5678)),
|
||||||
|
BootstrapRoom::Single(12345),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify single values were created
|
||||||
|
assert!(req.other.get("bootstrap_host").unwrap().is_string());
|
||||||
|
assert!(req.other.get("bootstrap_port").unwrap().is_number());
|
||||||
|
assert!(req.other.get("bootstrap_room").unwrap().is_number());
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
req.other.get("bootstrap_host").unwrap().as_str().unwrap(),
|
||||||
|
"test-server"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
req.other.get("bootstrap_port").unwrap().as_u64().unwrap(),
|
||||||
|
5678
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
req.other.get("bootstrap_room").unwrap().as_u64().unwrap(),
|
||||||
|
12345
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_completion_bootstrap_array_values() {
|
||||||
|
let mut req = CompletionRequest {
|
||||||
|
model: "test".to_string(),
|
||||||
|
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
||||||
|
n: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
logprobs: None,
|
||||||
|
echo: false,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
best_of: None,
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set bootstrap info with arrays
|
||||||
|
req.set_bootstrap_info(
|
||||||
|
BootstrapHost::Batch(vec!["test-server".to_string(); 2]),
|
||||||
|
BootstrapPort::Batch(vec![Some(5678); 2]),
|
||||||
|
BootstrapRoom::Batch(vec![12345, 67890]),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify arrays were created correctly
|
||||||
|
assert!(req.other.get("bootstrap_host").unwrap().is_array());
|
||||||
|
assert!(req.other.get("bootstrap_port").unwrap().is_array());
|
||||||
|
assert!(req.other.get("bootstrap_room").unwrap().is_array());
|
||||||
|
|
||||||
|
let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap();
|
||||||
|
assert_eq!(hosts.len(), 2);
|
||||||
|
assert_eq!(hosts[0].as_str().unwrap(), "test-server");
|
||||||
|
|
||||||
|
let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap();
|
||||||
|
assert_eq!(ports.len(), 2);
|
||||||
|
assert_eq!(ports[0].as_u64().unwrap(), 5678);
|
||||||
|
|
||||||
|
let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap();
|
||||||
|
assert_eq!(rooms.len(), 2);
|
||||||
|
assert_eq!(rooms[0].as_u64().unwrap(), 12345);
|
||||||
|
assert_eq!(rooms[1].as_u64().unwrap(), 67890);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -648,6 +648,7 @@ mod tests {
|
|||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
suffix: None,
|
suffix: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -687,6 +688,7 @@ mod tests {
|
|||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
suffix: None,
|
suffix: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -725,6 +727,7 @@ mod tests {
|
|||||||
user: Some("user123".to_string()),
|
user: Some("user123".to_string()),
|
||||||
seed: Some(42),
|
seed: Some(42),
|
||||||
suffix: Some("...".to_string()),
|
suffix: Some("...".to_string()),
|
||||||
|
other: serde_json::Map::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -768,6 +771,7 @@ mod tests {
|
|||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
suffix: None,
|
suffix: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -799,6 +803,7 @@ mod tests {
|
|||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
suffix: None,
|
suffix: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ fn test_benchmark_request_creation() {
|
|||||||
logit_bias: None,
|
logit_bias: None,
|
||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test serialization works
|
// Test serialization works
|
||||||
@@ -181,6 +182,7 @@ fn test_benchmark_request_adaptation() {
|
|||||||
logit_bias: None,
|
logit_bias: None,
|
||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test PD adaptation (should not panic)
|
// Test PD adaptation (should not panic)
|
||||||
|
|||||||
Reference in New Issue
Block a user