[router] fix req handling order, improve serialization, remove retry (#8888)
This commit is contained in:
@@ -2,12 +2,12 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri
|
||||
use serde_json::{from_str, to_string, to_value, to_vec};
|
||||
use std::time::Instant;
|
||||
|
||||
use sglang_router_rs::core::{BasicWorker, WorkerType};
|
||||
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
|
||||
use sglang_router_rs::openai_api_types::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||
SamplingParams, StringOrArray, UserMessageContent,
|
||||
};
|
||||
use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields;
|
||||
use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
|
||||
|
||||
fn create_test_worker() -> BasicWorker {
|
||||
BasicWorker::new(
|
||||
@@ -18,6 +18,16 @@ fn create_test_worker() -> BasicWorker {
|
||||
)
|
||||
}
|
||||
|
||||
// Helper function to get bootstrap info from worker
|
||||
fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option<u16>) {
|
||||
let hostname = get_hostname(worker.url());
|
||||
let bootstrap_port = match worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port.clone(),
|
||||
_ => None,
|
||||
};
|
||||
(hostname, bootstrap_port)
|
||||
}
|
||||
|
||||
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
||||
fn default_generate_request() -> GenerateRequest {
|
||||
GenerateRequest {
|
||||
@@ -331,35 +341,56 @@ fn bench_bootstrap_injection(c: &mut Criterion) {
|
||||
let completion_req = create_sample_completion_request();
|
||||
let large_chat_req = create_large_chat_completion_request();
|
||||
let worker = create_test_worker();
|
||||
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
|
||||
|
||||
group.bench_function("generate_bootstrap_injection", |b| {
|
||||
b.iter(|| {
|
||||
let mut json = to_value(black_box(&generate_req)).unwrap();
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &generate_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("chat_completion_bootstrap_injection", |b| {
|
||||
b.iter(|| {
|
||||
let mut json = to_value(black_box(&chat_req)).unwrap();
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &chat_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("completion_bootstrap_injection", |b| {
|
||||
b.iter(|| {
|
||||
let mut json = to_value(black_box(&completion_req)).unwrap();
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &completion_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
|
||||
group.bench_function("large_chat_completion_bootstrap_injection", |b| {
|
||||
b.iter(|| {
|
||||
let mut json = to_value(black_box(&large_chat_req)).unwrap();
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &large_chat_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(black_box(&request_with_bootstrap)).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
});
|
||||
@@ -441,6 +472,7 @@ fn bench_throughput_by_size(c: &mut Criterion) {
|
||||
};
|
||||
|
||||
let worker = create_test_worker();
|
||||
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
|
||||
|
||||
for (name, req) in [
|
||||
("small", &small_generate),
|
||||
@@ -449,6 +481,7 @@ fn bench_throughput_by_size(c: &mut Criterion) {
|
||||
] {
|
||||
let json = to_string(req).unwrap();
|
||||
let size_bytes = json.len();
|
||||
let hostname_clone = hostname.clone();
|
||||
|
||||
group.throughput(Throughput::Bytes(size_bytes as u64));
|
||||
group.bench_with_input(BenchmarkId::new("serialize", name), &req, |b, req| {
|
||||
@@ -472,10 +505,16 @@ fn bench_throughput_by_size(c: &mut Criterion) {
|
||||
group.bench_with_input(
|
||||
BenchmarkId::new("bootstrap_inject", name),
|
||||
&req,
|
||||
|b, req| {
|
||||
move |b, req| {
|
||||
let hostname = hostname_clone.clone();
|
||||
b.iter(|| {
|
||||
let mut json = to_value(req).unwrap();
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let json = to_value(&request_with_bootstrap).unwrap();
|
||||
black_box(json);
|
||||
});
|
||||
},
|
||||
@@ -493,17 +532,21 @@ fn bench_full_round_trip(c: &mut Criterion) {
|
||||
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
|
||||
let completion_json = to_string(&create_sample_completion_request()).unwrap();
|
||||
let worker = create_test_worker();
|
||||
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
|
||||
|
||||
group.bench_function("generate_openai_to_pd_pipeline", |b| {
|
||||
b.iter(|| {
|
||||
// Deserialize OpenAI request
|
||||
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
|
||||
// Convert to JSON Value
|
||||
let mut json = to_value(&req).unwrap();
|
||||
// Inject bootstrap fields
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
// Create wrapper with bootstrap fields
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
// Serialize final request
|
||||
let pd_json = to_string(&json).unwrap();
|
||||
let pd_json = to_string(&request_with_bootstrap).unwrap();
|
||||
black_box(pd_json);
|
||||
});
|
||||
});
|
||||
@@ -511,9 +554,13 @@ fn bench_full_round_trip(c: &mut Criterion) {
|
||||
group.bench_function("chat_completion_openai_to_pd_pipeline", |b| {
|
||||
b.iter(|| {
|
||||
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
|
||||
let mut json = to_value(&req).unwrap();
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
let pd_json = to_string(&json).unwrap();
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let pd_json = to_string(&request_with_bootstrap).unwrap();
|
||||
black_box(pd_json);
|
||||
});
|
||||
});
|
||||
@@ -521,9 +568,13 @@ fn bench_full_round_trip(c: &mut Criterion) {
|
||||
group.bench_function("completion_openai_to_pd_pipeline", |b| {
|
||||
b.iter(|| {
|
||||
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
|
||||
let mut json = to_value(&req).unwrap();
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
let pd_json = to_string(&json).unwrap();
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let pd_json = to_string(&request_with_bootstrap).unwrap();
|
||||
black_box(pd_json);
|
||||
});
|
||||
});
|
||||
@@ -575,10 +626,16 @@ fn benchmark_summary(c: &mut Criterion) {
|
||||
);
|
||||
|
||||
// Measure bootstrap injection (replaces adaptation)
|
||||
let (hostname, bootstrap_port) = get_bootstrap_info(&worker);
|
||||
let start = Instant::now();
|
||||
for _ in 0..1000 {
|
||||
let mut json = to_value(&generate_req).unwrap();
|
||||
let _ = black_box(inject_bootstrap_fields(&mut json, &worker));
|
||||
let request_with_bootstrap = RequestWithBootstrap {
|
||||
original: &generate_req,
|
||||
bootstrap_host: hostname.clone(),
|
||||
bootstrap_port,
|
||||
bootstrap_room: generate_room_id(),
|
||||
};
|
||||
let _ = black_box(to_value(&request_with_bootstrap).unwrap());
|
||||
}
|
||||
let inject_time = start.elapsed().as_nanos() / 1000;
|
||||
println!(" * Bootstrap Injection (avg): {:>6} ns/req", inject_time);
|
||||
|
||||
@@ -121,6 +121,8 @@ class BenchmarkRunner:
|
||||
results["serialization_time"] = self._extract_time(line)
|
||||
elif "Deserialization (avg):" in line:
|
||||
results["deserialization_time"] = self._extract_time(line)
|
||||
elif "Bootstrap Injection (avg):" in line:
|
||||
results["bootstrap_injection_time"] = self._extract_time(line)
|
||||
elif "Total Pipeline (avg):" in line:
|
||||
results["total_time"] = self._extract_time(line)
|
||||
|
||||
@@ -143,6 +145,7 @@ class BenchmarkRunner:
|
||||
thresholds = {
|
||||
"serialization_time": 2000, # 2μs max
|
||||
"deserialization_time": 2000, # 2μs max
|
||||
"bootstrap_injection_time": 5000, # 5μs max
|
||||
"total_time": 10000, # 10μs max
|
||||
}
|
||||
|
||||
|
||||
@@ -230,6 +230,10 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
"cache_aware"
|
||||
}
|
||||
|
||||
fn needs_request_text(&self) -> bool {
|
||||
true // Cache-aware policy needs request text for cache affinity
|
||||
}
|
||||
|
||||
fn on_request_complete(&self, worker_url: &str, success: bool) {
|
||||
// Could track success rates per worker for more intelligent routing
|
||||
if !success {
|
||||
|
||||
@@ -59,6 +59,11 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
|
||||
/// Get policy name for metrics and debugging
|
||||
fn name(&self) -> &'static str;
|
||||
|
||||
/// Check if this policy needs request text for routing decisions
|
||||
fn needs_request_text(&self) -> bool {
|
||||
false // Default: most policies don't need request text
|
||||
}
|
||||
|
||||
/// Update worker load information
|
||||
///
|
||||
/// This is called periodically with current load information for load-aware policies.
|
||||
|
||||
@@ -1,334 +0,0 @@
|
||||
// Bootstrap field injection for PD routing
|
||||
// Directly injects bootstrap fields into JSON requests without intermediate type conversions
|
||||
|
||||
use crate::core::{Worker, WorkerType};
|
||||
use crate::routers::pd_types::get_hostname;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// Inject bootstrap fields directly into a JSON request
|
||||
/// This replaces the complex ToPdRequest -> Bootstrap trait pattern
|
||||
pub fn inject_bootstrap_fields(json: &mut Value, worker: &dyn Worker) -> Result<(), String> {
|
||||
let batch_size = extract_batch_size(json)?;
|
||||
|
||||
// Extract bootstrap port from prefill worker if it's a prefill type
|
||||
let bootstrap_port = match worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let hostname = get_hostname(worker.url());
|
||||
|
||||
if let Some(batch_size) = batch_size {
|
||||
// Batch scenario - create arrays of bootstrap values
|
||||
json["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||
json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||
json["bootstrap_room"] = json!((0..batch_size)
|
||||
.map(|_| {
|
||||
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
|
||||
rand::random::<u64>() & (i64::MAX as u64)
|
||||
})
|
||||
.collect::<Vec<_>>());
|
||||
} else {
|
||||
// Single scenario - create single bootstrap values
|
||||
json["bootstrap_host"] = json!(hostname);
|
||||
json["bootstrap_port"] = json!(bootstrap_port);
|
||||
json["bootstrap_room"] = json!(rand::random::<u64>() & (i64::MAX as u64));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extract batch size from various JSON request formats
|
||||
/// Handles chat completions, completions, and generate requests
|
||||
fn extract_batch_size(json: &Value) -> Result<Option<usize>, String> {
|
||||
// Check for chat completions 'n' parameter (number of choices)
|
||||
if let Some(n) = json.get("n").and_then(|v| v.as_u64()) {
|
||||
if n > 1 {
|
||||
return Ok(Some(n as usize));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for array prompts (completions API)
|
||||
if let Some(prompt) = json.get("prompt") {
|
||||
if let Some(arr) = prompt.as_array() {
|
||||
if arr.is_empty() {
|
||||
return Err("Batch prompt array is empty".to_string());
|
||||
}
|
||||
return Ok(Some(arr.len()));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for array texts (generate API)
|
||||
if let Some(text) = json.get("text") {
|
||||
if let Some(arr) = text.as_array() {
|
||||
if arr.is_empty() {
|
||||
return Err("Batch text array is empty".to_string());
|
||||
}
|
||||
return Ok(Some(arr.len()));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for batch input_ids (generate API)
|
||||
if let Some(input_ids) = json.get("input_ids") {
|
||||
if let Some(arr) = input_ids.as_array() {
|
||||
if arr.is_empty() {
|
||||
return Err("Batch input_ids array is empty".to_string());
|
||||
}
|
||||
return Ok(Some(arr.len()));
|
||||
}
|
||||
}
|
||||
|
||||
// No batch indicators found - single request
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::BasicWorker;
|
||||
use serde_json::json;
|
||||
|
||||
fn create_test_worker() -> BasicWorker {
|
||||
BasicWorker::new(
|
||||
"http://test-server:8000".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: Some(5678),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inject_bootstrap_single_request() {
|
||||
let worker = create_test_worker();
|
||||
let mut json = json!({
|
||||
"model": "test-model",
|
||||
"prompt": "Hello world",
|
||||
"max_tokens": 100
|
||||
});
|
||||
|
||||
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify bootstrap fields were added
|
||||
assert_eq!(json["bootstrap_host"], json!("test-server"));
|
||||
assert_eq!(json["bootstrap_port"], json!(5678));
|
||||
assert!(json["bootstrap_room"].is_number());
|
||||
|
||||
// Verify original fields preserved
|
||||
assert_eq!(json["model"], json!("test-model"));
|
||||
assert_eq!(json["prompt"], json!("Hello world"));
|
||||
assert_eq!(json["max_tokens"], json!(100));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inject_bootstrap_batch_prompt() {
|
||||
let worker = create_test_worker();
|
||||
let mut json = json!({
|
||||
"model": "test-model",
|
||||
"prompt": ["Hello", "World"],
|
||||
"max_tokens": 100
|
||||
});
|
||||
|
||||
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify batch bootstrap fields
|
||||
assert_eq!(
|
||||
json["bootstrap_host"],
|
||||
json!(["test-server", "test-server"])
|
||||
);
|
||||
assert_eq!(json["bootstrap_port"], json!([5678, 5678]));
|
||||
|
||||
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_rooms.len(), 2);
|
||||
for room in bootstrap_rooms {
|
||||
assert!(room.is_number());
|
||||
let room_val = room.as_u64().unwrap();
|
||||
assert!(room_val <= i64::MAX as u64);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inject_bootstrap_chat_n_parameter() {
|
||||
let worker = create_test_worker();
|
||||
let mut json = json!({
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"n": 3
|
||||
});
|
||||
|
||||
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify batch bootstrap fields for n=3
|
||||
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_hosts.len(), 3);
|
||||
assert_eq!(bootstrap_hosts[0], json!("test-server"));
|
||||
|
||||
let bootstrap_ports = json["bootstrap_port"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_ports.len(), 3);
|
||||
assert_eq!(bootstrap_ports[0], json!(5678));
|
||||
|
||||
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_rooms.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inject_bootstrap_generate_text_array() {
|
||||
let worker = create_test_worker();
|
||||
let mut json = json!({
|
||||
"text": ["First prompt", "Second prompt"],
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify batch bootstrap fields
|
||||
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_hosts.len(), 2);
|
||||
|
||||
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_rooms.len(), 2);
|
||||
// Ensure room values are different (randomness)
|
||||
assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inject_bootstrap_input_ids_array() {
|
||||
let worker = create_test_worker();
|
||||
let mut json = json!({
|
||||
"input_ids": [[1, 2, 3], [4, 5, 6]],
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify batch bootstrap fields
|
||||
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_hosts.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_batch_size_empty_array_error() {
|
||||
let json = json!({
|
||||
"prompt": [],
|
||||
"model": "test"
|
||||
});
|
||||
|
||||
let result = extract_batch_size(&json);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("empty"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_batch_size_single_requests() {
|
||||
// Single string prompt
|
||||
let json = json!({
|
||||
"prompt": "Hello world",
|
||||
"model": "test"
|
||||
});
|
||||
assert_eq!(extract_batch_size(&json).unwrap(), None);
|
||||
|
||||
// Single text
|
||||
let json = json!({
|
||||
"text": "Hello world",
|
||||
"stream": false
|
||||
});
|
||||
assert_eq!(extract_batch_size(&json).unwrap(), None);
|
||||
|
||||
// Chat with n=1 (default)
|
||||
let json = json!({
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"n": 1
|
||||
});
|
||||
assert_eq!(extract_batch_size(&json).unwrap(), None);
|
||||
|
||||
// Chat without n parameter
|
||||
let json = json!({
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
});
|
||||
assert_eq!(extract_batch_size(&json).unwrap(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inject_bootstrap_preserves_sglang_fields() {
|
||||
let worker = create_test_worker();
|
||||
let mut json = json!({
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
// SGLang extensions should be preserved
|
||||
"top_k": 40,
|
||||
"min_p": 0.05,
|
||||
"repetition_penalty": 1.1,
|
||||
"regex": "test_pattern",
|
||||
"lora_path": "test.bin",
|
||||
"no_stop_trim": true,
|
||||
"ignore_eos": false
|
||||
});
|
||||
|
||||
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify bootstrap fields added
|
||||
assert!(json.get("bootstrap_host").is_some());
|
||||
assert!(json.get("bootstrap_port").is_some());
|
||||
assert!(json.get("bootstrap_room").is_some());
|
||||
|
||||
// Verify all SGLang fields preserved
|
||||
assert_eq!(json["top_k"], json!(40));
|
||||
assert_eq!(json["min_p"], json!(0.05));
|
||||
assert_eq!(json["repetition_penalty"], json!(1.1));
|
||||
assert_eq!(json["regex"], json!("test_pattern"));
|
||||
assert_eq!(json["lora_path"], json!("test.bin"));
|
||||
assert_eq!(json["no_stop_trim"], json!(true));
|
||||
assert_eq!(json["ignore_eos"], json!(false));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bootstrap_room_range() {
|
||||
let worker = create_test_worker();
|
||||
|
||||
// Test single request room generation
|
||||
for _ in 0..1000 {
|
||||
let mut json = json!({"prompt": "test"});
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
|
||||
let room = json["bootstrap_room"].as_u64().unwrap();
|
||||
assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
|
||||
}
|
||||
|
||||
// Test batch request room generation
|
||||
for _ in 0..100 {
|
||||
let mut json = json!({"prompt": ["test1", "test2"]});
|
||||
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||
|
||||
let rooms = json["bootstrap_room"].as_array().unwrap();
|
||||
for room_val in rooms {
|
||||
let room = room_val.as_u64().unwrap();
|
||||
assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_worker_without_bootstrap_port() {
|
||||
let worker = BasicWorker::new(
|
||||
"http://decode-only:8000".to_string(),
|
||||
WorkerType::Decode, // No bootstrap port
|
||||
);
|
||||
|
||||
let mut json = json!({
|
||||
"prompt": "Hello world"
|
||||
});
|
||||
|
||||
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify bootstrap fields with null port
|
||||
assert_eq!(json["bootstrap_host"], json!("decode-only"));
|
||||
assert_eq!(json["bootstrap_port"], json!(null));
|
||||
assert!(json["bootstrap_room"].is_number());
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,6 @@ use std::fmt::Debug;
|
||||
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
|
||||
pub mod bootstrap_injector;
|
||||
pub mod factory;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
// PD (Prefill-Decode) Router Implementation
|
||||
// This module handles routing for disaggregated prefill-decode systems
|
||||
|
||||
use super::bootstrap_injector::inject_bootstrap_fields;
|
||||
use super::pd_types::{api_path, PDRouterError};
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
||||
@@ -19,7 +17,6 @@ use axum::{
|
||||
Json,
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use rand::Rng;
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
@@ -316,17 +313,6 @@ impl PDRouter {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
// Helper to handle bootstrap injection errors
|
||||
fn handle_bootstrap_error(error: impl std::fmt::Display) -> Response {
|
||||
error!("Failed to add bootstrap info error={}", error);
|
||||
RouterMetrics::record_pd_error("bootstrap_injection");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Bootstrap injection failed: {}", error),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
// Helper to handle serialization errors
|
||||
fn handle_serialization_error(error: impl std::fmt::Display) -> Response {
|
||||
error!("Failed to serialize request error={}", error);
|
||||
@@ -337,110 +323,87 @@ impl PDRouter {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
// Execute the dual dispatch to prefill and decode servers with retry logic
|
||||
async fn execute_dual_dispatch(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
json_request: Value,
|
||||
route: &str,
|
||||
prefill: &dyn Worker,
|
||||
decode: &dyn Worker,
|
||||
is_stream: bool,
|
||||
return_logprob: bool,
|
||||
start_time: Instant,
|
||||
) -> Response {
|
||||
for attempt in 0..self.retry_config.max_retries {
|
||||
if attempt > 0 {
|
||||
// Calculate backoff with exponential growth and jitter
|
||||
let base_backoff = self.retry_config.initial_backoff_ms as f64
|
||||
* self
|
||||
.retry_config
|
||||
.backoff_multiplier
|
||||
.powf((attempt - 1) as f32) as f64;
|
||||
let backoff_ms = base_backoff.min(self.retry_config.max_backoff_ms as f64) as u64;
|
||||
|
||||
// Add jitter to prevent thundering herd
|
||||
let jitter = {
|
||||
let mut rng = rand::thread_rng();
|
||||
rng.gen_range(0..backoff_ms / 2)
|
||||
};
|
||||
let total_backoff = Duration::from_millis(backoff_ms + jitter);
|
||||
|
||||
info!(
|
||||
"Retrying request (attempt {}/{}) after {:?} backoff",
|
||||
attempt + 1,
|
||||
self.retry_config.max_retries,
|
||||
total_backoff
|
||||
);
|
||||
|
||||
tokio::time::sleep(total_backoff).await;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Executing request attempt {}/{}",
|
||||
attempt + 1,
|
||||
self.retry_config.max_retries
|
||||
);
|
||||
let result = self
|
||||
.execute_dual_dispatch_inner(
|
||||
headers,
|
||||
json_request.clone(),
|
||||
route,
|
||||
prefill,
|
||||
decode,
|
||||
is_stream,
|
||||
return_logprob,
|
||||
start_time,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Check if we should retry based on the response status
|
||||
let status = result.status();
|
||||
debug!(
|
||||
"Request attempt {} returned status: {}",
|
||||
attempt + 1,
|
||||
status
|
||||
);
|
||||
|
||||
// Don't retry client errors (4xx) or successful responses
|
||||
if status.is_client_error() || status.is_success() {
|
||||
debug!(
|
||||
"Returning response with status {} (no retry needed)",
|
||||
status
|
||||
);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Check if this is the last attempt
|
||||
if attempt == self.retry_config.max_retries - 1 {
|
||||
warn!("Final attempt failed with status {}", status);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Log retry decision for retryable errors
|
||||
if status.is_server_error()
|
||||
|| status == StatusCode::BAD_GATEWAY
|
||||
|| status == StatusCode::GATEWAY_TIMEOUT
|
||||
{
|
||||
warn!(
|
||||
"Retryable error status: {} on attempt {}/{}. Will retry.",
|
||||
status,
|
||||
attempt + 1,
|
||||
self.retry_config.max_retries
|
||||
);
|
||||
} else {
|
||||
// Don't retry other statuses
|
||||
debug!("Status {} is not retryable, returning response", status);
|
||||
return result;
|
||||
// Helper to determine batch size from a GenerateRequest
|
||||
fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
|
||||
// Check prompt array
|
||||
if let Some(prompt) = &req.prompt {
|
||||
if let crate::openai_api_types::StringOrArray::Array(arr) = prompt {
|
||||
if !arr.is_empty() {
|
||||
return Some(arr.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This should never be reached due to the loop logic, but just in case
|
||||
unreachable!("Retry loop completed without returning")
|
||||
// Check text array
|
||||
if let Some(text) = &req.text {
|
||||
if text.contains("[") && text.contains("]") {
|
||||
// This is a simplified check - in reality we'd need to parse JSON
|
||||
return None; // For now, fall back to non-batch
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Inner implementation of dual dispatch (extracted for retry logic)
|
||||
async fn execute_dual_dispatch_inner(
|
||||
// Helper to determine batch size from a ChatCompletionRequest
|
||||
fn get_chat_batch_size(req: &ChatCompletionRequest) -> Option<usize> {
|
||||
// Check 'n' parameter for multiple responses
|
||||
if let Some(n) = req.n {
|
||||
if n > 1 {
|
||||
return Some(n as usize);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Helper to determine batch size from a CompletionRequest
|
||||
fn get_completion_batch_size(req: &CompletionRequest) -> Option<usize> {
|
||||
// Check prompt array
|
||||
if let crate::openai_api_types::StringOrArray::Array(arr) = &req.prompt {
|
||||
if !arr.is_empty() {
|
||||
return Some(arr.len());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Helper to create request with bootstrap fields
|
||||
fn create_request_with_bootstrap<T: serde::Serialize>(
|
||||
request: &T,
|
||||
prefill_worker: &dyn Worker,
|
||||
batch_size: Option<usize>,
|
||||
) -> Result<serde_json::Value, serde_json::Error> {
|
||||
// Get bootstrap port from prefill worker
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
let hostname = super::pd_types::get_hostname(prefill_worker.url());
|
||||
|
||||
// Create optimized request with bootstrap fields
|
||||
if let Some(batch_size) = batch_size {
|
||||
// Batch request
|
||||
let request_with_bootstrap = super::pd_types::BatchRequestWithBootstrap {
|
||||
original: request,
|
||||
bootstrap_host: vec![hostname; batch_size],
|
||||
bootstrap_port: vec![bootstrap_port; batch_size],
|
||||
bootstrap_room: (0..batch_size)
|
||||
.map(|_| super::pd_types::generate_room_id())
|
||||
.collect(),
|
||||
};
|
||||
serde_json::to_value(&request_with_bootstrap)
|
||||
} else {
|
||||
// Single request
|
||||
let request_with_bootstrap = super::pd_types::RequestWithBootstrap {
|
||||
original: request,
|
||||
bootstrap_host: hostname,
|
||||
bootstrap_port,
|
||||
bootstrap_room: super::pd_types::generate_room_id(),
|
||||
};
|
||||
serde_json::to_value(&request_with_bootstrap)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the dual dispatch to prefill and decode servers
|
||||
async fn execute_dual_dispatch(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
json_request: Value,
|
||||
@@ -467,101 +430,195 @@ impl PDRouter {
|
||||
prefill.url(),
|
||||
decode.url()
|
||||
);
|
||||
let (prefill_result, decode_result) =
|
||||
tokio::join!(prefill_request.send(), decode_request.send());
|
||||
debug!("Received responses from both servers");
|
||||
|
||||
// Update metrics
|
||||
let duration = start_time.elapsed();
|
||||
RouterMetrics::record_pd_request_duration(route, duration);
|
||||
RouterMetrics::record_pd_request(route);
|
||||
RouterMetrics::record_pd_prefill_request(prefill.url());
|
||||
RouterMetrics::record_pd_decode_request(decode.url());
|
||||
if return_logprob {
|
||||
// When we need logprobs, wait for both responses
|
||||
let (prefill_result, decode_result) =
|
||||
tokio::join!(prefill_request.send(), decode_request.send());
|
||||
debug!("Received responses from both servers");
|
||||
|
||||
// Process prefill response
|
||||
let (_prefill_status, prefill_body) = match self
|
||||
.process_prefill_response(prefill_result, prefill.url(), return_logprob)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(error_response) => return error_response,
|
||||
};
|
||||
// Update metrics
|
||||
let duration = start_time.elapsed();
|
||||
RouterMetrics::record_pd_request_duration(route, duration);
|
||||
RouterMetrics::record_pd_request(route);
|
||||
RouterMetrics::record_pd_prefill_request(prefill.url());
|
||||
RouterMetrics::record_pd_decode_request(decode.url());
|
||||
|
||||
// Process decode response
|
||||
debug!("Processing decode response");
|
||||
match decode_result {
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
debug!("Decode response status: {}", status);
|
||||
// Process decode response with prefill for logprobs
|
||||
debug!("Processing decode response with logprobs");
|
||||
match decode_result {
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
debug!("Decode response status: {}", status);
|
||||
|
||||
if !status.is_success() {
|
||||
RouterMetrics::record_pd_decode_error(decode.url());
|
||||
error!(
|
||||
"Decode server returned error status decode_url={} status={}",
|
||||
decode.url(),
|
||||
status
|
||||
);
|
||||
if !status.is_success() {
|
||||
RouterMetrics::record_pd_decode_error(decode.url());
|
||||
error!(
|
||||
"Decode server returned error status decode_url={} status={}",
|
||||
decode.url(),
|
||||
status
|
||||
);
|
||||
|
||||
// Return the error response from decode server
|
||||
match res.bytes().await {
|
||||
Ok(error_body) => {
|
||||
return (status, error_body).into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
return (status, format!("Decode server error: {}", e)).into_response();
|
||||
// Return the error response from decode server
|
||||
match res.bytes().await {
|
||||
Ok(error_body) => {
|
||||
return (status, error_body).into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
return (status, format!("Decode server error: {}", e))
|
||||
.into_response();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if is_stream {
|
||||
// Streaming response
|
||||
let prefill_logprobs = if return_logprob {
|
||||
prefill_body
|
||||
// Process prefill response for logprobs
|
||||
let prefill_body = match self
|
||||
.process_prefill_response(prefill_result, prefill.url(), return_logprob)
|
||||
.await
|
||||
{
|
||||
Ok((_, body)) => body,
|
||||
Err(error_response) => return error_response,
|
||||
};
|
||||
|
||||
if is_stream {
|
||||
// Streaming response with logprobs
|
||||
let prefill_logprobs = prefill_body
|
||||
.as_ref()
|
||||
.and_then(|body| serde_json::from_slice::<Value>(body).ok())
|
||||
.and_then(|json| {
|
||||
json.pointer("/meta_info/input_token_logprobs").cloned()
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
});
|
||||
|
||||
let decode_url = if !return_logprob {
|
||||
Some(decode.url().to_string())
|
||||
Self::create_streaming_response(
|
||||
res.bytes_stream(),
|
||||
status,
|
||||
prefill_logprobs,
|
||||
return_logprob,
|
||||
None,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Self::create_streaming_response(
|
||||
res.bytes_stream(),
|
||||
status,
|
||||
prefill_logprobs,
|
||||
return_logprob,
|
||||
decode_url,
|
||||
)
|
||||
} else {
|
||||
// Non-streaming response - use helper
|
||||
self.process_non_streaming_response(res, status, return_logprob, prefill_body)
|
||||
// Non-streaming response with logprobs
|
||||
self.process_non_streaming_response(
|
||||
res,
|
||||
status,
|
||||
return_logprob,
|
||||
prefill_body,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
decode_url = %decode.url(),
|
||||
error = %e,
|
||||
"Decode request failed"
|
||||
);
|
||||
RouterMetrics::record_pd_decode_error(decode.url());
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Decode server error: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
decode_url = %decode.url(),
|
||||
error = %e,
|
||||
"Decode request failed"
|
||||
);
|
||||
RouterMetrics::record_pd_decode_error(decode.url());
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Decode server error: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
} else {
|
||||
// When we don't need logprobs, only wait for decode response
|
||||
// Send both requests concurrently but don't wait for prefill
|
||||
// Add headers to minimize response size when we don't need the body
|
||||
let prefill_future = prefill_request.header("Connection", "close").send();
|
||||
let decode_future = decode_request.send();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Ok(response) = prefill_future.await {
|
||||
// Consume with a short timeout to free connection quickly
|
||||
let consume_future = async {
|
||||
let _ = response.bytes().await;
|
||||
};
|
||||
|
||||
// Give it 100ms to consume, then abandon
|
||||
let _ = tokio::time::timeout(Duration::from_millis(100), consume_future).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Wait only for decode response
|
||||
let decode_result = decode_future.await;
|
||||
debug!("Received decode response");
|
||||
|
||||
// Update metrics
|
||||
let duration = start_time.elapsed();
|
||||
RouterMetrics::record_pd_request_duration(route, duration);
|
||||
RouterMetrics::record_pd_request(route);
|
||||
RouterMetrics::record_pd_prefill_request(prefill.url());
|
||||
RouterMetrics::record_pd_decode_request(decode.url());
|
||||
|
||||
// Process decode response immediately
|
||||
debug!("Processing decode response (no logprobs)");
|
||||
match decode_result {
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
debug!("Decode response status: {}", status);
|
||||
|
||||
if !status.is_success() {
|
||||
RouterMetrics::record_pd_decode_error(decode.url());
|
||||
error!(
|
||||
"Decode server returned error status decode_url={} status={}",
|
||||
decode.url(),
|
||||
status
|
||||
);
|
||||
|
||||
// Return the error response from decode server
|
||||
match res.bytes().await {
|
||||
Ok(error_body) => (status, error_body).into_response(),
|
||||
Err(e) => {
|
||||
(status, format!("Decode server error: {}", e)).into_response()
|
||||
}
|
||||
}
|
||||
} else if is_stream {
|
||||
// Streaming response without logprobs - direct passthrough
|
||||
let decode_url = decode.url().to_string();
|
||||
Self::create_streaming_response(
|
||||
res.bytes_stream(),
|
||||
status,
|
||||
None,
|
||||
false,
|
||||
Some(decode_url),
|
||||
)
|
||||
} else {
|
||||
// Non-streaming response without logprobs - direct passthrough like fast version
|
||||
match res.bytes().await {
|
||||
Ok(decode_body) => (status, decode_body).into_response(),
|
||||
Err(e) => {
|
||||
error!("Failed to read decode response: {}", e);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
decode_url = %decode.url(),
|
||||
error = %e,
|
||||
"Decode request failed"
|
||||
);
|
||||
RouterMetrics::record_pd_decode_error(decode.url());
|
||||
(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Decode server error: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if either prefill or decode policy needs request text
|
||||
fn policies_need_request_text(&self) -> bool {
|
||||
self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text()
|
||||
}
|
||||
|
||||
// Select a pair of prefill and decode servers
|
||||
async fn select_pd_pair(
|
||||
&self,
|
||||
@@ -1311,23 +1368,23 @@ impl RouterTrait for PDRouter {
|
||||
) -> Response {
|
||||
let start = Instant::now();
|
||||
|
||||
// Convert directly to JSON to preserve all fields automatically
|
||||
let mut json = match serde_json::to_value(body) {
|
||||
Ok(json) => json,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
|
||||
// Extract flags for routing logic
|
||||
let is_stream = body.stream;
|
||||
let return_logprob = body.return_logprob;
|
||||
|
||||
// Extract text for cache-aware routing
|
||||
let request_text = body.text.as_deref().or_else(|| {
|
||||
body.prompt.as_ref().and_then(|p| match p {
|
||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
||||
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
|
||||
// Extract text for cache-aware routing only if needed
|
||||
let request_text = if self.policies_need_request_text() {
|
||||
body.text.as_deref().or_else(|| {
|
||||
body.prompt.as_ref().and_then(|p| match p {
|
||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
||||
crate::openai_api_types::StringOrArray::Array(v) => {
|
||||
v.first().map(|s| s.as_str())
|
||||
}
|
||||
})
|
||||
})
|
||||
});
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Select servers
|
||||
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
||||
@@ -1342,10 +1399,12 @@ impl RouterTrait for PDRouter {
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Inject bootstrap fields directly into JSON
|
||||
if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) {
|
||||
return Self::handle_bootstrap_error(e);
|
||||
}
|
||||
// Create optimized request with bootstrap fields
|
||||
let batch_size = Self::get_generate_batch_size(body);
|
||||
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
|
||||
Ok(json) => json,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
|
||||
// Execute dual dispatch
|
||||
self.execute_dual_dispatch(
|
||||
@@ -1368,27 +1427,29 @@ impl RouterTrait for PDRouter {
|
||||
) -> Response {
|
||||
let start = Instant::now();
|
||||
|
||||
// Convert directly to JSON to preserve all fields automatically
|
||||
let mut json = match serde_json::to_value(body) {
|
||||
Ok(json) => json,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
|
||||
// Extract flags for routing logic
|
||||
let is_stream = body.stream;
|
||||
let return_logprob = body.logprobs;
|
||||
|
||||
// Extract text for cache-aware routing from chat messages
|
||||
let request_text = body.messages.first().and_then(|msg| match msg {
|
||||
crate::openai_api_types::ChatMessage::User { content, .. } => {
|
||||
match content {
|
||||
crate::openai_api_types::UserMessageContent::Text(text) => Some(text.as_str()),
|
||||
crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content
|
||||
// Extract text for cache-aware routing from chat messages only if needed
|
||||
let request_text = if self.policies_need_request_text() {
|
||||
body.messages.first().and_then(|msg| match msg {
|
||||
crate::openai_api_types::ChatMessage::User { content, .. } => {
|
||||
match content {
|
||||
crate::openai_api_types::UserMessageContent::Text(text) => {
|
||||
Some(text.as_str())
|
||||
}
|
||||
crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content
|
||||
}
|
||||
}
|
||||
}
|
||||
crate::openai_api_types::ChatMessage::System { content, .. } => Some(content.as_str()),
|
||||
_ => None,
|
||||
});
|
||||
crate::openai_api_types::ChatMessage::System { content, .. } => {
|
||||
Some(content.as_str())
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Select servers
|
||||
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
||||
@@ -1403,10 +1464,12 @@ impl RouterTrait for PDRouter {
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Inject bootstrap fields directly into JSON
|
||||
if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) {
|
||||
return Self::handle_bootstrap_error(e);
|
||||
}
|
||||
// Create optimized request with bootstrap fields
|
||||
let batch_size = Self::get_chat_batch_size(body);
|
||||
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
|
||||
Ok(json) => json,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
|
||||
// Execute dual dispatch
|
||||
self.execute_dual_dispatch(
|
||||
@@ -1429,20 +1492,18 @@ impl RouterTrait for PDRouter {
|
||||
) -> Response {
|
||||
let start = Instant::now();
|
||||
|
||||
// Convert directly to JSON to preserve all fields automatically
|
||||
let mut json = match serde_json::to_value(body) {
|
||||
Ok(json) => json,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
|
||||
// Extract flags for routing logic
|
||||
let is_stream = body.stream;
|
||||
let return_logprob = body.logprobs.is_some();
|
||||
|
||||
// Extract text for cache-aware routing
|
||||
let request_text = match &body.prompt {
|
||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
||||
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
|
||||
// Extract text for cache-aware routing only if needed
|
||||
let request_text = if self.policies_need_request_text() {
|
||||
match &body.prompt {
|
||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
||||
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Select servers
|
||||
@@ -1458,10 +1519,12 @@ impl RouterTrait for PDRouter {
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Inject bootstrap fields directly into JSON
|
||||
if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) {
|
||||
return Self::handle_bootstrap_error(e);
|
||||
}
|
||||
// Create optimized request with bootstrap fields
|
||||
let batch_size = Self::get_completion_batch_size(body);
|
||||
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
|
||||
Ok(json) => json,
|
||||
Err(e) => return Self::handle_serialization_error(e),
|
||||
};
|
||||
|
||||
// Execute dual dispatch
|
||||
self.execute_dual_dispatch(
|
||||
@@ -1937,6 +2000,13 @@ mod tests {
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
// ============= Bootstrap Injection Tests =============
|
||||
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
|
||||
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
|
||||
|
||||
// TODO: Add new tests for the optimized bootstrap injection approach using
|
||||
// RequestWithBootstrap and BatchRequestWithBootstrap wrappers
|
||||
|
||||
// ============= Worker Selection Tests =============
|
||||
|
||||
#[tokio::test]
|
||||
@@ -2114,158 +2184,4 @@ mod tests {
|
||||
let workers = router.prefill_workers.read().unwrap();
|
||||
assert_eq!(workers.len(), 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_simplified_routing_preserves_sglang_fields() {
|
||||
use crate::openai_api_types::GenerateRequest;
|
||||
use crate::routers::bootstrap_injector::inject_bootstrap_fields;
|
||||
|
||||
// Create a test worker
|
||||
let worker = BasicWorker::new(
|
||||
"http://test-server:8000".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: Some(5678),
|
||||
},
|
||||
);
|
||||
|
||||
// Create a GenerateRequest with SGLang extensions
|
||||
let mut session_params = std::collections::HashMap::new();
|
||||
session_params.insert("test_key".to_string(), serde_json::json!("test_value"));
|
||||
|
||||
let request = GenerateRequest {
|
||||
text: Some("Test prompt".to_string()),
|
||||
stream: false,
|
||||
return_logprob: true,
|
||||
// SGLang extensions
|
||||
lora_path: Some(crate::openai_api_types::LoRAPath::Single(Some(
|
||||
"test.bin".to_string(),
|
||||
))),
|
||||
session_params: Some(session_params.clone()),
|
||||
return_hidden_states: true,
|
||||
rid: Some("test-request-id".to_string()),
|
||||
// Other fields default to None/false
|
||||
prompt: None,
|
||||
input_ids: None,
|
||||
parameters: None,
|
||||
sampling_params: None,
|
||||
};
|
||||
|
||||
// Convert to JSON (simulating the simplified routing path)
|
||||
let mut json = serde_json::to_value(&request).unwrap();
|
||||
|
||||
// Inject bootstrap fields
|
||||
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify all SGLang fields are preserved
|
||||
assert_eq!(json["text"], serde_json::json!("Test prompt"));
|
||||
assert_eq!(json["stream"], serde_json::json!(false));
|
||||
assert_eq!(json["return_logprob"], serde_json::json!(true));
|
||||
assert_eq!(json["lora_path"], serde_json::json!("test.bin")); // LoRAPath::Single serializes as just the inner value
|
||||
assert_eq!(
|
||||
json["session_params"],
|
||||
serde_json::to_value(&session_params).unwrap()
|
||||
);
|
||||
assert_eq!(json["return_hidden_states"], serde_json::json!(true));
|
||||
assert_eq!(json["rid"], serde_json::json!("test-request-id"));
|
||||
|
||||
// Verify bootstrap fields were added
|
||||
assert_eq!(json["bootstrap_host"], serde_json::json!("test-server"));
|
||||
assert_eq!(json["bootstrap_port"], serde_json::json!(5678));
|
||||
assert!(json["bootstrap_room"].is_number());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_simplified_routing_chat_completion() {
|
||||
use crate::openai_api_types::{ChatCompletionRequest, ChatMessage, UserMessageContent};
|
||||
use crate::routers::bootstrap_injector::inject_bootstrap_fields;
|
||||
|
||||
// Create a test worker
|
||||
let worker = BasicWorker::new(
|
||||
"http://chat-server:8000".to_string(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: Some(9999),
|
||||
},
|
||||
);
|
||||
|
||||
// Create a ChatCompletionRequest with SGLang extensions
|
||||
let request = ChatCompletionRequest {
|
||||
model: "gpt-4".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text("Hello world!".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
stream: false,
|
||||
n: Some(2), // This should create batch bootstrap
|
||||
// SGLang extensions
|
||||
top_k: Some(50),
|
||||
separate_reasoning: false,
|
||||
stream_reasoning: true,
|
||||
// Set all other fields to defaults
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
stream_options: None,
|
||||
stop: None,
|
||||
max_tokens: None,
|
||||
max_completion_tokens: None,
|
||||
presence_penalty: None,
|
||||
frequency_penalty: None,
|
||||
logit_bias: None,
|
||||
user: None,
|
||||
seed: None,
|
||||
logprobs: false,
|
||||
top_logprobs: None,
|
||||
response_format: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
functions: None,
|
||||
function_call: None,
|
||||
min_p: None,
|
||||
min_tokens: None,
|
||||
repetition_penalty: None,
|
||||
regex: None,
|
||||
ebnf: None,
|
||||
stop_token_ids: None,
|
||||
no_stop_trim: false,
|
||||
ignore_eos: false,
|
||||
continue_final_message: false,
|
||||
skip_special_tokens: true,
|
||||
lora_path: None,
|
||||
session_params: None,
|
||||
return_hidden_states: false,
|
||||
};
|
||||
|
||||
// Convert to JSON (simulating the simplified routing path)
|
||||
let mut json = serde_json::to_value(&request).unwrap();
|
||||
|
||||
// Inject bootstrap fields
|
||||
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify original fields preserved
|
||||
assert_eq!(json["model"], serde_json::json!("gpt-4"));
|
||||
assert_eq!(json["stream"], serde_json::json!(false));
|
||||
assert_eq!(json["n"], serde_json::json!(2));
|
||||
assert_eq!(json["top_k"], serde_json::json!(50));
|
||||
assert_eq!(json["separate_reasoning"], serde_json::json!(false));
|
||||
assert_eq!(json["stream_reasoning"], serde_json::json!(true));
|
||||
|
||||
// Verify batch bootstrap fields for n=2
|
||||
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_hosts.len(), 2);
|
||||
assert_eq!(bootstrap_hosts[0], serde_json::json!("chat-server"));
|
||||
assert_eq!(bootstrap_hosts[1], serde_json::json!("chat-server"));
|
||||
|
||||
let bootstrap_ports = json["bootstrap_port"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_ports.len(), 2);
|
||||
assert_eq!(bootstrap_ports[0], serde_json::json!(9999));
|
||||
assert_eq!(bootstrap_ports[1], serde_json::json!(9999));
|
||||
|
||||
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
|
||||
assert_eq!(bootstrap_rooms.len(), 2);
|
||||
// Rooms should be different (randomness)
|
||||
assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,34 @@ pub fn get_hostname(url: &str) -> String {
|
||||
url.split(':').next().unwrap_or("localhost").to_string()
|
||||
}
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
// Optimized bootstrap wrapper for single requests
|
||||
#[derive(Serialize)]
|
||||
pub struct RequestWithBootstrap<'a, T: Serialize> {
|
||||
#[serde(flatten)]
|
||||
pub original: &'a T,
|
||||
pub bootstrap_host: String,
|
||||
pub bootstrap_port: Option<u16>,
|
||||
pub bootstrap_room: u64,
|
||||
}
|
||||
|
||||
// Optimized bootstrap wrapper for batch requests
|
||||
#[derive(Serialize)]
|
||||
pub struct BatchRequestWithBootstrap<'a, T: Serialize> {
|
||||
#[serde(flatten)]
|
||||
pub original: &'a T,
|
||||
pub bootstrap_host: Vec<String>,
|
||||
pub bootstrap_port: Vec<Option<u16>>,
|
||||
pub bootstrap_room: Vec<u64>,
|
||||
}
|
||||
|
||||
// Helper to generate bootstrap room ID
|
||||
pub fn generate_room_id() -> u64 {
|
||||
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
|
||||
rand::random::<u64>() & (i64::MAX as u64)
|
||||
}
|
||||
|
||||
// PD-specific routing policies
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum PDSelectionPolicy {
|
||||
|
||||
@@ -269,7 +269,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
|
||||
let client = Client::builder()
|
||||
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
||||
.pool_max_idle_per_host(100) // Increase from default of 1 to allow more concurrent connections
|
||||
.pool_max_idle_per_host(500) // Increase to 500 connections per host
|
||||
.timeout(Duration::from_secs(config.request_timeout_secs))
|
||||
.connect_timeout(Duration::from_secs(10)) // Separate connection timeout
|
||||
.tcp_nodelay(true)
|
||||
|
||||
@@ -9,7 +9,6 @@ use sglang_router_rs::openai_api_types::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||
SamplingParams, StringOrArray, UserMessageContent,
|
||||
};
|
||||
use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields;
|
||||
|
||||
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
||||
fn default_generate_request() -> GenerateRequest {
|
||||
@@ -208,63 +207,6 @@ fn test_benchmark_serialization_roundtrip() {
|
||||
assert_eq!(generate_req.return_logprob, deserialized.return_logprob);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_benchmark_bootstrap_injection() {
|
||||
// Test that bootstrap injection works for benchmark types (replaces PD request adaptation)
|
||||
|
||||
let generate_req = GenerateRequest {
|
||||
text: Some("Test prompt".to_string()),
|
||||
..default_generate_request()
|
||||
};
|
||||
|
||||
let chat_req = ChatCompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
messages: vec![ChatMessage::User {
|
||||
role: "user".to_string(),
|
||||
content: UserMessageContent::Text("Test message".to_string()),
|
||||
name: None,
|
||||
}],
|
||||
max_tokens: Some(150),
|
||||
max_completion_tokens: Some(150),
|
||||
temperature: Some(0.7),
|
||||
top_p: Some(1.0),
|
||||
n: Some(1),
|
||||
presence_penalty: Some(0.0),
|
||||
frequency_penalty: Some(0.0),
|
||||
parallel_tool_calls: Some(true),
|
||||
..default_chat_completion_request()
|
||||
};
|
||||
|
||||
let completion_req = CompletionRequest {
|
||||
model: "test-model".to_string(),
|
||||
prompt: StringOrArray::String("Test prompt".to_string()),
|
||||
max_tokens: Some(50),
|
||||
temperature: Some(0.8),
|
||||
top_p: Some(1.0),
|
||||
n: Some(1),
|
||||
presence_penalty: Some(0.0),
|
||||
frequency_penalty: Some(0.0),
|
||||
best_of: Some(1),
|
||||
..default_completion_request()
|
||||
};
|
||||
|
||||
let worker = create_test_worker();
|
||||
|
||||
// Test bootstrap injection (should not panic)
|
||||
let mut generate_json = to_value(&generate_req).unwrap();
|
||||
let mut chat_json = to_value(&chat_req).unwrap();
|
||||
let mut completion_json = to_value(&completion_req).unwrap();
|
||||
|
||||
assert!(inject_bootstrap_fields(&mut generate_json, &worker).is_ok());
|
||||
assert!(inject_bootstrap_fields(&mut chat_json, &worker).is_ok());
|
||||
assert!(inject_bootstrap_fields(&mut completion_json, &worker).is_ok());
|
||||
|
||||
// Verify bootstrap fields were added
|
||||
assert!(generate_json.get("bootstrap_host").is_some());
|
||||
assert!(generate_json.get("bootstrap_port").is_some());
|
||||
assert!(generate_json.get("bootstrap_room").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_benchmark_direct_json_routing() {
|
||||
// Test direct JSON routing functionality for benchmark types (replaces regular routing)
|
||||
@@ -283,47 +225,3 @@ fn test_benchmark_direct_json_routing() {
|
||||
assert!(!json_string.is_empty());
|
||||
assert!(!bytes.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_benchmark_performance_baseline() {
|
||||
// Basic performance sanity check - ensure operations complete quickly
|
||||
use std::time::Instant;
|
||||
|
||||
let generate_req = GenerateRequest {
|
||||
text: Some("Short test prompt".to_string()),
|
||||
..default_generate_request()
|
||||
};
|
||||
|
||||
// Test the actual simplified pipeline: to_value + bootstrap injection
|
||||
let start = Instant::now();
|
||||
let worker = create_test_worker();
|
||||
|
||||
// This mirrors the actual router pipeline
|
||||
let mut json = to_value(&generate_req).unwrap();
|
||||
let _ = inject_bootstrap_fields(&mut json, &worker);
|
||||
|
||||
let total_duration = start.elapsed();
|
||||
assert!(
|
||||
total_duration.as_millis() < 5,
|
||||
"Simplified pipeline took too long: {:?} (should be faster than old adapter approach)",
|
||||
total_duration
|
||||
);
|
||||
|
||||
// Individual components should also be fast
|
||||
let start = Instant::now();
|
||||
let _json = to_value(&generate_req).unwrap();
|
||||
let to_value_duration = start.elapsed();
|
||||
|
||||
let start = Instant::now();
|
||||
let mut json = to_value(&generate_req).unwrap();
|
||||
let _ = inject_bootstrap_fields(&mut json, &worker);
|
||||
let inject_duration = start.elapsed();
|
||||
|
||||
// Bootstrap injection should be faster than the JSON conversion
|
||||
assert!(
|
||||
inject_duration <= to_value_duration * 3,
|
||||
"Bootstrap injection ({:?}) should not be much slower than JSON conversion ({:?})",
|
||||
inject_duration,
|
||||
to_value_duration
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user