[router] PD Router Simplification and Reorganization (#8838)

This commit is contained in:
Simo Lin
2025-08-05 21:20:38 -07:00
committed by GitHub
parent ca47e24f5d
commit 8c7bb39dfb
8 changed files with 1220 additions and 2677 deletions

View File

@@ -0,0 +1,334 @@
// Bootstrap field injection for PD routing
// Directly injects bootstrap fields into JSON requests without intermediate type conversions
use crate::core::{Worker, WorkerType};
use crate::routers::pd_types::get_hostname;
use serde_json::{json, Value};
/// Inject bootstrap fields directly into a JSON request
/// This replaces the complex ToPdRequest -> Bootstrap trait pattern
pub fn inject_bootstrap_fields(json: &mut Value, worker: &dyn Worker) -> Result<(), String> {
let batch_size = extract_batch_size(json)?;
// Extract bootstrap port from prefill worker if it's a prefill type
let bootstrap_port = match worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = get_hostname(worker.url());
if let Some(batch_size) = batch_size {
// Batch scenario - create arrays of bootstrap values
json["bootstrap_host"] = json!(vec![hostname; batch_size]);
json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
json["bootstrap_room"] = json!((0..batch_size)
.map(|_| {
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand::random::<u64>() & (i64::MAX as u64)
})
.collect::<Vec<_>>());
} else {
// Single scenario - create single bootstrap values
json["bootstrap_host"] = json!(hostname);
json["bootstrap_port"] = json!(bootstrap_port);
json["bootstrap_room"] = json!(rand::random::<u64>() & (i64::MAX as u64));
}
Ok(())
}
/// Extract batch size from various JSON request formats
/// Handles chat completions, completions, and generate requests
fn extract_batch_size(json: &Value) -> Result<Option<usize>, String> {
// Check for chat completions 'n' parameter (number of choices)
if let Some(n) = json.get("n").and_then(|v| v.as_u64()) {
if n > 1 {
return Ok(Some(n as usize));
}
}
// Check for array prompts (completions API)
if let Some(prompt) = json.get("prompt") {
if let Some(arr) = prompt.as_array() {
if arr.is_empty() {
return Err("Batch prompt array is empty".to_string());
}
return Ok(Some(arr.len()));
}
}
// Check for array texts (generate API)
if let Some(text) = json.get("text") {
if let Some(arr) = text.as_array() {
if arr.is_empty() {
return Err("Batch text array is empty".to_string());
}
return Ok(Some(arr.len()));
}
}
// Check for batch input_ids (generate API)
if let Some(input_ids) = json.get("input_ids") {
if let Some(arr) = input_ids.as_array() {
if arr.is_empty() {
return Err("Batch input_ids array is empty".to_string());
}
return Ok(Some(arr.len()));
}
}
// No batch indicators found - single request
Ok(None)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::BasicWorker;
use serde_json::json;
fn create_test_worker() -> BasicWorker {
BasicWorker::new(
"http://test-server:8000".to_string(),
WorkerType::Prefill {
bootstrap_port: Some(5678),
},
)
}
#[test]
fn test_inject_bootstrap_single_request() {
let worker = create_test_worker();
let mut json = json!({
"model": "test-model",
"prompt": "Hello world",
"max_tokens": 100
});
let result = inject_bootstrap_fields(&mut json, &worker);
assert!(result.is_ok());
// Verify bootstrap fields were added
assert_eq!(json["bootstrap_host"], json!("test-server"));
assert_eq!(json["bootstrap_port"], json!(5678));
assert!(json["bootstrap_room"].is_number());
// Verify original fields preserved
assert_eq!(json["model"], json!("test-model"));
assert_eq!(json["prompt"], json!("Hello world"));
assert_eq!(json["max_tokens"], json!(100));
}
#[test]
fn test_inject_bootstrap_batch_prompt() {
let worker = create_test_worker();
let mut json = json!({
"model": "test-model",
"prompt": ["Hello", "World"],
"max_tokens": 100
});
let result = inject_bootstrap_fields(&mut json, &worker);
assert!(result.is_ok());
// Verify batch bootstrap fields
assert_eq!(
json["bootstrap_host"],
json!(["test-server", "test-server"])
);
assert_eq!(json["bootstrap_port"], json!([5678, 5678]));
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
assert_eq!(bootstrap_rooms.len(), 2);
for room in bootstrap_rooms {
assert!(room.is_number());
let room_val = room.as_u64().unwrap();
assert!(room_val <= i64::MAX as u64);
}
}
#[test]
fn test_inject_bootstrap_chat_n_parameter() {
let worker = create_test_worker();
let mut json = json!({
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
"n": 3
});
let result = inject_bootstrap_fields(&mut json, &worker);
assert!(result.is_ok());
// Verify batch bootstrap fields for n=3
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
assert_eq!(bootstrap_hosts.len(), 3);
assert_eq!(bootstrap_hosts[0], json!("test-server"));
let bootstrap_ports = json["bootstrap_port"].as_array().unwrap();
assert_eq!(bootstrap_ports.len(), 3);
assert_eq!(bootstrap_ports[0], json!(5678));
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
assert_eq!(bootstrap_rooms.len(), 3);
}
#[test]
fn test_inject_bootstrap_generate_text_array() {
let worker = create_test_worker();
let mut json = json!({
"text": ["First prompt", "Second prompt"],
"stream": false
});
let result = inject_bootstrap_fields(&mut json, &worker);
assert!(result.is_ok());
// Verify batch bootstrap fields
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
assert_eq!(bootstrap_hosts.len(), 2);
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
assert_eq!(bootstrap_rooms.len(), 2);
// Ensure room values are different (randomness)
assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]);
}
#[test]
fn test_inject_bootstrap_input_ids_array() {
let worker = create_test_worker();
let mut json = json!({
"input_ids": [[1, 2, 3], [4, 5, 6]],
"stream": false
});
let result = inject_bootstrap_fields(&mut json, &worker);
assert!(result.is_ok());
// Verify batch bootstrap fields
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
assert_eq!(bootstrap_hosts.len(), 2);
}
#[test]
fn test_extract_batch_size_empty_array_error() {
let json = json!({
"prompt": [],
"model": "test"
});
let result = extract_batch_size(&json);
assert!(result.is_err());
assert!(result.unwrap_err().contains("empty"));
}
#[test]
fn test_extract_batch_size_single_requests() {
// Single string prompt
let json = json!({
"prompt": "Hello world",
"model": "test"
});
assert_eq!(extract_batch_size(&json).unwrap(), None);
// Single text
let json = json!({
"text": "Hello world",
"stream": false
});
assert_eq!(extract_batch_size(&json).unwrap(), None);
// Chat with n=1 (default)
let json = json!({
"messages": [{"role": "user", "content": "Hello"}],
"n": 1
});
assert_eq!(extract_batch_size(&json).unwrap(), None);
// Chat without n parameter
let json = json!({
"messages": [{"role": "user", "content": "Hello"}]
});
assert_eq!(extract_batch_size(&json).unwrap(), None);
}
#[test]
fn test_inject_bootstrap_preserves_sglang_fields() {
let worker = create_test_worker();
let mut json = json!({
"model": "test-model",
"prompt": "Hello",
// SGLang extensions should be preserved
"top_k": 40,
"min_p": 0.05,
"repetition_penalty": 1.1,
"regex": "test_pattern",
"lora_path": "test.bin",
"no_stop_trim": true,
"ignore_eos": false
});
let result = inject_bootstrap_fields(&mut json, &worker);
assert!(result.is_ok());
// Verify bootstrap fields added
assert!(json.get("bootstrap_host").is_some());
assert!(json.get("bootstrap_port").is_some());
assert!(json.get("bootstrap_room").is_some());
// Verify all SGLang fields preserved
assert_eq!(json["top_k"], json!(40));
assert_eq!(json["min_p"], json!(0.05));
assert_eq!(json["repetition_penalty"], json!(1.1));
assert_eq!(json["regex"], json!("test_pattern"));
assert_eq!(json["lora_path"], json!("test.bin"));
assert_eq!(json["no_stop_trim"], json!(true));
assert_eq!(json["ignore_eos"], json!(false));
}
#[test]
fn test_bootstrap_room_range() {
let worker = create_test_worker();
// Test single request room generation
for _ in 0..1000 {
let mut json = json!({"prompt": "test"});
inject_bootstrap_fields(&mut json, &worker).unwrap();
let room = json["bootstrap_room"].as_u64().unwrap();
assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
}
// Test batch request room generation
for _ in 0..100 {
let mut json = json!({"prompt": ["test1", "test2"]});
inject_bootstrap_fields(&mut json, &worker).unwrap();
let rooms = json["bootstrap_room"].as_array().unwrap();
for room_val in rooms {
let room = room_val.as_u64().unwrap();
assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
}
}
}
#[test]
fn test_worker_without_bootstrap_port() {
let worker = BasicWorker::new(
"http://decode-only:8000".to_string(),
WorkerType::Decode, // No bootstrap port
);
let mut json = json!({
"prompt": "Hello world"
});
let result = inject_bootstrap_fields(&mut json, &worker);
assert!(result.is_ok());
// Verify bootstrap fields with null port
assert_eq!(json["bootstrap_host"], json!("decode-only"));
assert_eq!(json["bootstrap_port"], json!(null));
assert!(json["bootstrap_room"].is_number());
}
}

View File

@@ -11,10 +11,10 @@ use std::fmt::Debug;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod bootstrap_injector;
pub mod factory;
pub mod pd_router;
pub mod pd_types;
pub mod request_adapter;
pub mod router;
pub use factory::RouterFactory;

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,3 @@
// Essential PDLB types extracted for PD routing
use crate::core::{Worker, WorkerType};
use crate::openai_api_types::{CompletionRequest, StringOrArray};
use serde::{Deserialize, Serialize};
use serde_json::Value;
// Custom error type for PD router operations
#[derive(Debug, thiserror::Error)]
pub enum PDRouterError {
@@ -58,428 +51,3 @@ pub enum PDSelectionPolicy {
balance_rel_threshold: f32,
},
}
// Bootstrap types from PDLB
#[derive(Debug, Deserialize, Serialize, PartialEq)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
Single(T),
Batch(Vec<T>),
}
pub type InputIds = SingleOrBatch<Vec<i32>>;
pub type InputText = SingleOrBatch<String>;
pub type BootstrapHost = SingleOrBatch<String>;
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
pub type BootstrapRoom = SingleOrBatch<u64>;
// Bootstrap trait for request handling
pub trait Bootstrap: Send + Sync {
fn is_stream(&self) -> bool;
fn get_batch_size(&self) -> Result<Option<usize>, String>;
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
);
fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
let batch_size = self.get_batch_size()?;
// Extract bootstrap port from prefill worker if it's a prefill type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = get_hostname(prefill_worker.url());
if let Some(batch_size) = batch_size {
self.set_bootstrap_info(
BootstrapHost::Batch(vec![hostname; batch_size]),
BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
// Use high-quality random numbers to minimize collision risk
BootstrapRoom::Batch(
(0..batch_size)
.map(|_| {
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand::random::<u64>() & (i64::MAX as u64)
})
.collect(),
),
);
} else {
self.set_bootstrap_info(
BootstrapHost::Single(hostname),
BootstrapPort::Single(bootstrap_port),
BootstrapRoom::Single(
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
rand::random::<u64>() & (i64::MAX as u64),
),
);
}
Ok(())
}
}
// Request types
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
pub text: Option<InputText>,
pub input_ids: Option<InputIds>,
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
impl GenerateReqInput {
pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
if self.text.is_some() && self.input_ids.is_some() {
return Err("Both text and input_ids are present in the request".to_string());
}
// Check text batch
if let Some(InputText::Batch(texts)) = &self.text {
if texts.is_empty() {
return Err("Batch text array is empty".to_string());
}
return Ok(Some(texts.len()));
}
// Check input_ids batch
if let Some(InputIds::Batch(ids)) = &self.input_ids {
if ids.is_empty() {
return Err("Batch input_ids array is empty".to_string());
}
// Validate each sequence is not empty
for (i, seq) in ids.iter().enumerate() {
if seq.is_empty() {
return Err(format!("Input sequence at index {} is empty", i));
}
}
return Ok(Some(ids.len()));
}
Ok(None)
}
}
impl Bootstrap for GenerateReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
self.get_batch_size()
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
impl Bootstrap for ChatReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
// Check if 'n' parameter is present and > 1
if let Some(n_value) = self.other.get("n") {
if let Some(n) = n_value.as_u64() {
if n > 1 {
return Ok(Some(n as usize));
}
}
}
Ok(None)
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
// Bootstrap implementation for CompletionRequest to preserve OpenAI format
impl Bootstrap for CompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
if let StringOrArray::Array(prompts) = &self.prompt {
if prompts.is_empty() {
return Err("Batch prompt array is empty".to_string());
}
return Ok(Some(prompts.len()));
}
// Single string prompt
Ok(None)
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
// Insert bootstrap_host - it serializes correctly whether Single or Batch
if let Ok(host_value) = serde_json::to_value(&bootstrap_host) {
self.other.insert("bootstrap_host".to_string(), host_value);
}
// Insert bootstrap_port - it serializes correctly whether Single or Batch
if let Ok(port_value) = serde_json::to_value(&bootstrap_port) {
self.other.insert("bootstrap_port".to_string(), port_value);
}
// Insert bootstrap_room - it serializes correctly whether Single or Batch
if let Ok(room_value) = serde_json::to_value(&bootstrap_room) {
self.other.insert("bootstrap_room".to_string(), room_value);
}
}
}
#[cfg(test)]
mod bootstrap_tests {
use super::*;
use crate::core::BasicWorker;
use crate::openai_api_types::StringOrArray;
/// Create a default CompletionRequest for testing with minimal fields set
fn default_completion_request() -> CompletionRequest {
CompletionRequest {
model: String::new(),
prompt: StringOrArray::String(String::new()),
n: None,
other: serde_json::Map::new(),
suffix: None,
max_tokens: None,
temperature: None,
top_p: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
// SGLang Extensions
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
json_schema: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
skip_special_tokens: true,
// SGLang Extensions
lora_path: None,
session_params: None,
return_hidden_states: false,
}
}
#[test]
fn test_completion_batch_size_with_array_prompt() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
..default_completion_request()
};
// Should return batch size for array prompt
assert_eq!(req.get_batch_size().unwrap(), Some(2));
}
#[test]
fn test_completion_batch_size_with_single_prompt() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::String("single prompt".to_string()),
..default_completion_request()
};
// Should return None for single prompt
assert_eq!(req.get_batch_size().unwrap(), None);
}
#[test]
fn test_completion_batch_size_with_n_parameter() {
let req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::String("single prompt".to_string()),
n: Some(3),
..default_completion_request()
};
// Should return None for single string prompt, even with n > 1
// SGLang handles n parameter differently than batch requests
assert_eq!(req.get_batch_size().unwrap(), None);
}
#[test]
fn test_completion_bootstrap_single_values() {
let mut req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
..default_completion_request()
};
// Set bootstrap info - should always use single values
req.set_bootstrap_info(
BootstrapHost::Single("test-server".to_string()),
BootstrapPort::Single(Some(5678)),
BootstrapRoom::Single(12345),
);
// Verify single values were created
assert!(req.other.get("bootstrap_host").unwrap().is_string());
assert!(req.other.get("bootstrap_port").unwrap().is_number());
assert!(req.other.get("bootstrap_room").unwrap().is_number());
assert_eq!(
req.other.get("bootstrap_host").unwrap().as_str().unwrap(),
"test-server"
);
assert_eq!(
req.other.get("bootstrap_port").unwrap().as_u64().unwrap(),
5678
);
assert_eq!(
req.other.get("bootstrap_room").unwrap().as_u64().unwrap(),
12345
);
}
#[test]
fn test_completion_bootstrap_array_values() {
let mut req = CompletionRequest {
model: "test".to_string(),
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
..default_completion_request()
};
// Set bootstrap info with arrays
req.set_bootstrap_info(
BootstrapHost::Batch(vec!["test-server".to_string(); 2]),
BootstrapPort::Batch(vec![Some(5678); 2]),
BootstrapRoom::Batch(vec![12345, 67890]),
);
// Verify arrays were created correctly
assert!(req.other.get("bootstrap_host").unwrap().is_array());
assert!(req.other.get("bootstrap_port").unwrap().is_array());
assert!(req.other.get("bootstrap_room").unwrap().is_array());
let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap();
assert_eq!(hosts.len(), 2);
assert_eq!(hosts[0].as_str().unwrap(), "test-server");
let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap();
assert_eq!(ports.len(), 2);
assert_eq!(ports[0].as_u64().unwrap(), 5678);
let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap();
assert_eq!(rooms.len(), 2);
assert_eq!(rooms[0].as_u64().unwrap(), 12345);
assert_eq!(rooms[1].as_u64().unwrap(), 67890);
}
#[test]
fn test_bootstrap_room_range() {
// Test that bootstrap_room values are within the expected range [0, 2^63 - 1]
let worker = BasicWorker::new(
"http://test:8000".to_string(),
WorkerType::Prefill {
bootstrap_port: Some(8080),
},
);
// Test single request
let mut single_req = GenerateReqInput {
text: Some(InputText::Single("test".to_string())),
input_ids: None,
stream: false,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(serde_json::Map::new()),
};
for _ in 0..200000 {
single_req.add_bootstrap_info(&worker).unwrap();
if let Some(BootstrapRoom::Single(room)) = single_req.bootstrap_room {
// Verify the room value is within signed 64-bit range
assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
} else {
panic!("Expected single bootstrap room");
}
}
// Test batch request
let mut batch_req = GenerateReqInput {
text: Some(InputText::Batch(vec![
"test1".to_string(),
"test2".to_string(),
])),
input_ids: None,
stream: false,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(serde_json::Map::new()),
};
for _ in 0..200000 {
batch_req.add_bootstrap_info(&worker).unwrap();
if let Some(BootstrapRoom::Batch(rooms)) = &batch_req.bootstrap_room {
for room in rooms {
// Verify each room value is within signed 64-bit range
assert!(*room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
}
} else {
panic!("Expected batch bootstrap rooms");
}
}
}
}

File diff suppressed because it is too large Load Diff