[router] PD Router Simplification and Reorganization (#8838)
This commit is contained in:
@@ -1,12 +1,22 @@
|
|||||||
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||||
use serde_json::{from_str, to_string, to_vec};
|
use serde_json::{from_str, to_string, to_value, to_vec};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use sglang_router_rs::core::{BasicWorker, WorkerType};
|
||||||
use sglang_router_rs::openai_api_types::{
|
use sglang_router_rs::openai_api_types::{
|
||||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||||
SamplingParams, StringOrArray, UserMessageContent,
|
SamplingParams, StringOrArray, UserMessageContent,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
|
use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields;
|
||||||
|
|
||||||
|
fn create_test_worker() -> BasicWorker {
|
||||||
|
BasicWorker::new(
|
||||||
|
"http://test-server:8000".to_string(),
|
||||||
|
WorkerType::Prefill {
|
||||||
|
bootstrap_port: Some(5678),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
||||||
fn default_generate_request() -> GenerateRequest {
|
fn default_generate_request() -> GenerateRequest {
|
||||||
@@ -312,49 +322,54 @@ fn bench_json_deserialization(c: &mut Criterion) {
|
|||||||
group.finish();
|
group.finish();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark request adaptation from OpenAI to PD format
|
// Benchmark bootstrap injection (replaces request adaptation)
|
||||||
fn bench_request_adaptation(c: &mut Criterion) {
|
fn bench_bootstrap_injection(c: &mut Criterion) {
|
||||||
let mut group = c.benchmark_group("request_adaptation");
|
let mut group = c.benchmark_group("bootstrap_injection");
|
||||||
|
|
||||||
let generate_req = create_sample_generate_request();
|
let generate_req = create_sample_generate_request();
|
||||||
let chat_req = create_sample_chat_completion_request();
|
let chat_req = create_sample_chat_completion_request();
|
||||||
let completion_req = create_sample_completion_request();
|
let completion_req = create_sample_completion_request();
|
||||||
let large_chat_req = create_large_chat_completion_request();
|
let large_chat_req = create_large_chat_completion_request();
|
||||||
|
let worker = create_test_worker();
|
||||||
|
|
||||||
group.bench_function("generate_to_pd", |b| {
|
group.bench_function("generate_bootstrap_injection", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let pd_req = black_box(generate_req.clone()).to_pd_request();
|
let mut json = to_value(black_box(&generate_req)).unwrap();
|
||||||
black_box(pd_req);
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
black_box(json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
group.bench_function("chat_completion_to_pd", |b| {
|
group.bench_function("chat_completion_bootstrap_injection", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let pd_req = black_box(chat_req.clone()).to_pd_request();
|
let mut json = to_value(black_box(&chat_req)).unwrap();
|
||||||
black_box(pd_req);
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
black_box(json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
group.bench_function("completion_to_pd", |b| {
|
group.bench_function("completion_bootstrap_injection", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let pd_req = black_box(completion_req.clone()).to_pd_request();
|
let mut json = to_value(black_box(&completion_req)).unwrap();
|
||||||
black_box(pd_req);
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
black_box(json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
group.bench_function("large_chat_completion_to_pd", |b| {
|
group.bench_function("large_chat_completion_bootstrap_injection", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let pd_req = black_box(large_chat_req.clone()).to_pd_request();
|
let mut json = to_value(black_box(&large_chat_req)).unwrap();
|
||||||
black_box(pd_req);
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
black_box(json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
group.finish();
|
group.finish();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark regular routing (RouteableRequest methods)
|
// Benchmark direct JSON routing (replaces regular routing)
|
||||||
fn bench_regular_routing(c: &mut Criterion) {
|
fn bench_direct_json_routing(c: &mut Criterion) {
|
||||||
let mut group = c.benchmark_group("regular_routing");
|
let mut group = c.benchmark_group("direct_json_routing");
|
||||||
|
|
||||||
let generate_req = create_sample_generate_request();
|
let generate_req = create_sample_generate_request();
|
||||||
let chat_req = create_sample_chat_completion_request();
|
let chat_req = create_sample_chat_completion_request();
|
||||||
@@ -362,35 +377,42 @@ fn bench_regular_routing(c: &mut Criterion) {
|
|||||||
|
|
||||||
group.bench_function("generate_to_json", |b| {
|
group.bench_function("generate_to_json", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let json = black_box(&generate_req).to_json().unwrap();
|
let json = to_value(black_box(&generate_req)).unwrap();
|
||||||
|
black_box(json);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
group.bench_function("generate_to_json_string", |b| {
|
||||||
|
b.iter(|| {
|
||||||
|
let json = to_string(black_box(&generate_req)).unwrap();
|
||||||
black_box(json);
|
black_box(json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
group.bench_function("generate_to_bytes", |b| {
|
group.bench_function("generate_to_bytes", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let bytes = black_box(&generate_req).to_bytes().unwrap();
|
let bytes = to_vec(black_box(&generate_req)).unwrap();
|
||||||
black_box(bytes);
|
black_box(bytes);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
group.bench_function("chat_completion_to_json", |b| {
|
group.bench_function("chat_completion_to_json", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let json = black_box(&chat_req).to_json().unwrap();
|
let json = to_value(black_box(&chat_req)).unwrap();
|
||||||
black_box(json);
|
black_box(json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
group.bench_function("chat_completion_to_bytes", |b| {
|
group.bench_function("chat_completion_to_json_string", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let bytes = black_box(&chat_req).to_bytes().unwrap();
|
let json = to_string(black_box(&chat_req)).unwrap();
|
||||||
black_box(bytes);
|
black_box(json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
group.bench_function("completion_to_json", |b| {
|
group.bench_function("completion_to_json", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let json = black_box(&completion_req).to_json().unwrap();
|
let json = to_value(black_box(&completion_req)).unwrap();
|
||||||
black_box(json);
|
black_box(json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -418,6 +440,8 @@ fn bench_throughput_by_size(c: &mut Criterion) {
|
|||||||
..default_generate_request()
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let worker = create_test_worker();
|
||||||
|
|
||||||
for (name, req) in [
|
for (name, req) in [
|
||||||
("small", &small_generate),
|
("small", &small_generate),
|
||||||
("medium", &medium_generate),
|
("medium", &medium_generate),
|
||||||
@@ -445,33 +469,41 @@ fn bench_throughput_by_size(c: &mut Criterion) {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
group.bench_with_input(BenchmarkId::new("adapt_to_pd", name), &req, |b, req| {
|
group.bench_with_input(
|
||||||
b.iter(|| {
|
BenchmarkId::new("bootstrap_inject", name),
|
||||||
let pd_req = (*req).clone().to_pd_request();
|
&req,
|
||||||
black_box(pd_req);
|
|b, req| {
|
||||||
});
|
b.iter(|| {
|
||||||
});
|
let mut json = to_value(req).unwrap();
|
||||||
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
black_box(json);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
group.finish();
|
group.finish();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark full round-trip: deserialize -> adapt -> serialize
|
// Benchmark full round-trip: deserialize -> inject bootstrap -> serialize
|
||||||
fn bench_full_round_trip(c: &mut Criterion) {
|
fn bench_full_round_trip(c: &mut Criterion) {
|
||||||
let mut group = c.benchmark_group("full_round_trip");
|
let mut group = c.benchmark_group("full_round_trip");
|
||||||
|
|
||||||
let generate_json = to_string(&create_sample_generate_request()).unwrap();
|
let generate_json = to_string(&create_sample_generate_request()).unwrap();
|
||||||
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
|
let chat_json = to_string(&create_sample_chat_completion_request()).unwrap();
|
||||||
let completion_json = to_string(&create_sample_completion_request()).unwrap();
|
let completion_json = to_string(&create_sample_completion_request()).unwrap();
|
||||||
|
let worker = create_test_worker();
|
||||||
|
|
||||||
group.bench_function("generate_openai_to_pd_pipeline", |b| {
|
group.bench_function("generate_openai_to_pd_pipeline", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
// Deserialize OpenAI request
|
// Deserialize OpenAI request
|
||||||
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
|
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
|
||||||
// Adapt to PD format
|
// Convert to JSON Value
|
||||||
let pd_req = req.to_pd_request();
|
let mut json = to_value(&req).unwrap();
|
||||||
// Serialize PD request
|
// Inject bootstrap fields
|
||||||
let pd_json = to_string(&pd_req).unwrap();
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
// Serialize final request
|
||||||
|
let pd_json = to_string(&json).unwrap();
|
||||||
black_box(pd_json);
|
black_box(pd_json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -479,8 +511,9 @@ fn bench_full_round_trip(c: &mut Criterion) {
|
|||||||
group.bench_function("chat_completion_openai_to_pd_pipeline", |b| {
|
group.bench_function("chat_completion_openai_to_pd_pipeline", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
|
let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap();
|
||||||
let pd_req = req.to_pd_request();
|
let mut json = to_value(&req).unwrap();
|
||||||
let pd_json = to_string(&pd_req).unwrap();
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
let pd_json = to_string(&json).unwrap();
|
||||||
black_box(pd_json);
|
black_box(pd_json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -488,19 +521,21 @@ fn bench_full_round_trip(c: &mut Criterion) {
|
|||||||
group.bench_function("completion_openai_to_pd_pipeline", |b| {
|
group.bench_function("completion_openai_to_pd_pipeline", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
|
let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap();
|
||||||
let pd_req = req.to_pd_request();
|
let mut json = to_value(&req).unwrap();
|
||||||
let pd_json = to_string(&pd_req).unwrap();
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
let pd_json = to_string(&json).unwrap();
|
||||||
black_box(pd_json);
|
black_box(pd_json);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
group.bench_function("generate_regular_routing_pipeline", |b| {
|
group.bench_function("generate_direct_json_pipeline", |b| {
|
||||||
b.iter(|| {
|
b.iter(|| {
|
||||||
// Deserialize OpenAI request
|
// Deserialize OpenAI request
|
||||||
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
|
let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap();
|
||||||
// Convert to JSON for regular routing
|
// Convert to JSON for direct routing (no bootstrap injection)
|
||||||
let routing_json = req.to_json().unwrap();
|
let routing_json = to_value(&req).unwrap();
|
||||||
black_box(routing_json);
|
let json_string = to_string(&routing_json).unwrap();
|
||||||
|
black_box(json_string);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -515,6 +550,7 @@ fn benchmark_summary(c: &mut Criterion) {
|
|||||||
|
|
||||||
// Quick performance overview
|
// Quick performance overview
|
||||||
let generate_req = create_sample_generate_request();
|
let generate_req = create_sample_generate_request();
|
||||||
|
let worker = create_test_worker();
|
||||||
|
|
||||||
println!("\nQuick Performance Overview:");
|
println!("\nQuick Performance Overview:");
|
||||||
|
|
||||||
@@ -538,32 +574,39 @@ fn benchmark_summary(c: &mut Criterion) {
|
|||||||
deserialize_time
|
deserialize_time
|
||||||
);
|
);
|
||||||
|
|
||||||
// Measure adaptation
|
// Measure bootstrap injection (replaces adaptation)
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
for _ in 0..1000 {
|
for _ in 0..1000 {
|
||||||
let _ = black_box(generate_req.clone().to_pd_request());
|
let mut json = to_value(&generate_req).unwrap();
|
||||||
|
let _ = black_box(inject_bootstrap_fields(&mut json, &worker));
|
||||||
}
|
}
|
||||||
let adapt_time = start.elapsed().as_nanos() / 1000;
|
let inject_time = start.elapsed().as_nanos() / 1000;
|
||||||
println!(" * PD Adaptation (avg): {:>8} ns/req", adapt_time);
|
println!(" * Bootstrap Injection (avg): {:>6} ns/req", inject_time);
|
||||||
|
|
||||||
// Calculate ratios
|
// Calculate ratios
|
||||||
let total_pipeline = serialize_time + deserialize_time + adapt_time;
|
let total_pipeline = serialize_time + deserialize_time + inject_time;
|
||||||
println!(" * Total Pipeline (avg): {:>8} ns/req", total_pipeline);
|
println!(" * Total Pipeline (avg): {:>8} ns/req", total_pipeline);
|
||||||
|
|
||||||
println!("\nPerformance Insights:");
|
println!("\nPerformance Insights:");
|
||||||
if deserialize_time > serialize_time * 2 {
|
if deserialize_time > serialize_time * 2 {
|
||||||
println!(" • Deserialization is significantly faster than serialization");
|
println!(" • Deserialization is significantly faster than serialization");
|
||||||
}
|
}
|
||||||
if adapt_time < serialize_time / 10 {
|
if inject_time < serialize_time / 10 {
|
||||||
println!(
|
println!(
|
||||||
" • PD adaptation overhead is negligible ({:.1}% of serialization)",
|
" • Bootstrap injection overhead is negligible ({:.1}% of serialization)",
|
||||||
(adapt_time as f64 / serialize_time as f64) * 100.0
|
(inject_time as f64 / serialize_time as f64) * 100.0
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if total_pipeline < 10_000 {
|
if total_pipeline < 100_000 {
|
||||||
println!(" • Total pipeline latency is excellent (< 10μs)");
|
println!(" • Total pipeline latency is excellent (< 100μs)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
println!("\nSimplification Benefits:");
|
||||||
|
println!(" • Eliminated complex type conversion layer");
|
||||||
|
println!(" • Reduced memory allocations");
|
||||||
|
println!(" • Automatic field preservation (no manual mapping)");
|
||||||
|
println!(" • Direct JSON manipulation improves performance");
|
||||||
|
|
||||||
println!("\nRecommendations:");
|
println!("\nRecommendations:");
|
||||||
if serialize_time > deserialize_time {
|
if serialize_time > deserialize_time {
|
||||||
println!(" • Focus optimization efforts on serialization rather than deserialization");
|
println!(" • Focus optimization efforts on serialization rather than deserialization");
|
||||||
@@ -581,8 +624,8 @@ criterion_group!(
|
|||||||
benchmark_summary,
|
benchmark_summary,
|
||||||
bench_json_serialization,
|
bench_json_serialization,
|
||||||
bench_json_deserialization,
|
bench_json_deserialization,
|
||||||
bench_request_adaptation,
|
bench_bootstrap_injection,
|
||||||
bench_regular_routing,
|
bench_direct_json_routing,
|
||||||
bench_throughput_by_size,
|
bench_throughput_by_size,
|
||||||
bench_full_round_trip
|
bench_full_round_trip
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -121,8 +121,6 @@ class BenchmarkRunner:
|
|||||||
results["serialization_time"] = self._extract_time(line)
|
results["serialization_time"] = self._extract_time(line)
|
||||||
elif "Deserialization (avg):" in line:
|
elif "Deserialization (avg):" in line:
|
||||||
results["deserialization_time"] = self._extract_time(line)
|
results["deserialization_time"] = self._extract_time(line)
|
||||||
elif "PD Adaptation (avg):" in line:
|
|
||||||
results["adaptation_time"] = self._extract_time(line)
|
|
||||||
elif "Total Pipeline (avg):" in line:
|
elif "Total Pipeline (avg):" in line:
|
||||||
results["total_time"] = self._extract_time(line)
|
results["total_time"] = self._extract_time(line)
|
||||||
|
|
||||||
@@ -145,7 +143,6 @@ class BenchmarkRunner:
|
|||||||
thresholds = {
|
thresholds = {
|
||||||
"serialization_time": 2000, # 2μs max
|
"serialization_time": 2000, # 2μs max
|
||||||
"deserialization_time": 2000, # 2μs max
|
"deserialization_time": 2000, # 2μs max
|
||||||
"adaptation_time": 5000, # 5μs max
|
|
||||||
"total_time": 10000, # 10μs max
|
"total_time": 10000, # 10μs max
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
334
sgl-router/src/routers/bootstrap_injector.rs
Normal file
334
sgl-router/src/routers/bootstrap_injector.rs
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
// Bootstrap field injection for PD routing
|
||||||
|
// Directly injects bootstrap fields into JSON requests without intermediate type conversions
|
||||||
|
|
||||||
|
use crate::core::{Worker, WorkerType};
|
||||||
|
use crate::routers::pd_types::get_hostname;
|
||||||
|
use serde_json::{json, Value};
|
||||||
|
|
||||||
|
/// Inject bootstrap fields directly into a JSON request
|
||||||
|
/// This replaces the complex ToPdRequest -> Bootstrap trait pattern
|
||||||
|
pub fn inject_bootstrap_fields(json: &mut Value, worker: &dyn Worker) -> Result<(), String> {
|
||||||
|
let batch_size = extract_batch_size(json)?;
|
||||||
|
|
||||||
|
// Extract bootstrap port from prefill worker if it's a prefill type
|
||||||
|
let bootstrap_port = match worker.worker_type() {
|
||||||
|
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let hostname = get_hostname(worker.url());
|
||||||
|
|
||||||
|
if let Some(batch_size) = batch_size {
|
||||||
|
// Batch scenario - create arrays of bootstrap values
|
||||||
|
json["bootstrap_host"] = json!(vec![hostname; batch_size]);
|
||||||
|
json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
|
||||||
|
json["bootstrap_room"] = json!((0..batch_size)
|
||||||
|
.map(|_| {
|
||||||
|
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
|
||||||
|
rand::random::<u64>() & (i64::MAX as u64)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>());
|
||||||
|
} else {
|
||||||
|
// Single scenario - create single bootstrap values
|
||||||
|
json["bootstrap_host"] = json!(hostname);
|
||||||
|
json["bootstrap_port"] = json!(bootstrap_port);
|
||||||
|
json["bootstrap_room"] = json!(rand::random::<u64>() & (i64::MAX as u64));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract batch size from various JSON request formats
|
||||||
|
/// Handles chat completions, completions, and generate requests
|
||||||
|
fn extract_batch_size(json: &Value) -> Result<Option<usize>, String> {
|
||||||
|
// Check for chat completions 'n' parameter (number of choices)
|
||||||
|
if let Some(n) = json.get("n").and_then(|v| v.as_u64()) {
|
||||||
|
if n > 1 {
|
||||||
|
return Ok(Some(n as usize));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for array prompts (completions API)
|
||||||
|
if let Some(prompt) = json.get("prompt") {
|
||||||
|
if let Some(arr) = prompt.as_array() {
|
||||||
|
if arr.is_empty() {
|
||||||
|
return Err("Batch prompt array is empty".to_string());
|
||||||
|
}
|
||||||
|
return Ok(Some(arr.len()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for array texts (generate API)
|
||||||
|
if let Some(text) = json.get("text") {
|
||||||
|
if let Some(arr) = text.as_array() {
|
||||||
|
if arr.is_empty() {
|
||||||
|
return Err("Batch text array is empty".to_string());
|
||||||
|
}
|
||||||
|
return Ok(Some(arr.len()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for batch input_ids (generate API)
|
||||||
|
if let Some(input_ids) = json.get("input_ids") {
|
||||||
|
if let Some(arr) = input_ids.as_array() {
|
||||||
|
if arr.is_empty() {
|
||||||
|
return Err("Batch input_ids array is empty".to_string());
|
||||||
|
}
|
||||||
|
return Ok(Some(arr.len()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No batch indicators found - single request
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::core::BasicWorker;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
fn create_test_worker() -> BasicWorker {
|
||||||
|
BasicWorker::new(
|
||||||
|
"http://test-server:8000".to_string(),
|
||||||
|
WorkerType::Prefill {
|
||||||
|
bootstrap_port: Some(5678),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_inject_bootstrap_single_request() {
|
||||||
|
let worker = create_test_worker();
|
||||||
|
let mut json = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello world",
|
||||||
|
"max_tokens": 100
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Verify bootstrap fields were added
|
||||||
|
assert_eq!(json["bootstrap_host"], json!("test-server"));
|
||||||
|
assert_eq!(json["bootstrap_port"], json!(5678));
|
||||||
|
assert!(json["bootstrap_room"].is_number());
|
||||||
|
|
||||||
|
// Verify original fields preserved
|
||||||
|
assert_eq!(json["model"], json!("test-model"));
|
||||||
|
assert_eq!(json["prompt"], json!("Hello world"));
|
||||||
|
assert_eq!(json["max_tokens"], json!(100));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_inject_bootstrap_batch_prompt() {
|
||||||
|
let worker = create_test_worker();
|
||||||
|
let mut json = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": ["Hello", "World"],
|
||||||
|
"max_tokens": 100
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Verify batch bootstrap fields
|
||||||
|
assert_eq!(
|
||||||
|
json["bootstrap_host"],
|
||||||
|
json!(["test-server", "test-server"])
|
||||||
|
);
|
||||||
|
assert_eq!(json["bootstrap_port"], json!([5678, 5678]));
|
||||||
|
|
||||||
|
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
|
||||||
|
assert_eq!(bootstrap_rooms.len(), 2);
|
||||||
|
for room in bootstrap_rooms {
|
||||||
|
assert!(room.is_number());
|
||||||
|
let room_val = room.as_u64().unwrap();
|
||||||
|
assert!(room_val <= i64::MAX as u64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_inject_bootstrap_chat_n_parameter() {
|
||||||
|
let worker = create_test_worker();
|
||||||
|
let mut json = json!({
|
||||||
|
"model": "gpt-4",
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"n": 3
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Verify batch bootstrap fields for n=3
|
||||||
|
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
|
||||||
|
assert_eq!(bootstrap_hosts.len(), 3);
|
||||||
|
assert_eq!(bootstrap_hosts[0], json!("test-server"));
|
||||||
|
|
||||||
|
let bootstrap_ports = json["bootstrap_port"].as_array().unwrap();
|
||||||
|
assert_eq!(bootstrap_ports.len(), 3);
|
||||||
|
assert_eq!(bootstrap_ports[0], json!(5678));
|
||||||
|
|
||||||
|
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
|
||||||
|
assert_eq!(bootstrap_rooms.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_inject_bootstrap_generate_text_array() {
|
||||||
|
let worker = create_test_worker();
|
||||||
|
let mut json = json!({
|
||||||
|
"text": ["First prompt", "Second prompt"],
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Verify batch bootstrap fields
|
||||||
|
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
|
||||||
|
assert_eq!(bootstrap_hosts.len(), 2);
|
||||||
|
|
||||||
|
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
|
||||||
|
assert_eq!(bootstrap_rooms.len(), 2);
|
||||||
|
// Ensure room values are different (randomness)
|
||||||
|
assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_inject_bootstrap_input_ids_array() {
|
||||||
|
let worker = create_test_worker();
|
||||||
|
let mut json = json!({
|
||||||
|
"input_ids": [[1, 2, 3], [4, 5, 6]],
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Verify batch bootstrap fields
|
||||||
|
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
|
||||||
|
assert_eq!(bootstrap_hosts.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_batch_size_empty_array_error() {
|
||||||
|
let json = json!({
|
||||||
|
"prompt": [],
|
||||||
|
"model": "test"
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = extract_batch_size(&json);
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("empty"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extract_batch_size_single_requests() {
|
||||||
|
// Single string prompt
|
||||||
|
let json = json!({
|
||||||
|
"prompt": "Hello world",
|
||||||
|
"model": "test"
|
||||||
|
});
|
||||||
|
assert_eq!(extract_batch_size(&json).unwrap(), None);
|
||||||
|
|
||||||
|
// Single text
|
||||||
|
let json = json!({
|
||||||
|
"text": "Hello world",
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
assert_eq!(extract_batch_size(&json).unwrap(), None);
|
||||||
|
|
||||||
|
// Chat with n=1 (default)
|
||||||
|
let json = json!({
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"n": 1
|
||||||
|
});
|
||||||
|
assert_eq!(extract_batch_size(&json).unwrap(), None);
|
||||||
|
|
||||||
|
// Chat without n parameter
|
||||||
|
let json = json!({
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}]
|
||||||
|
});
|
||||||
|
assert_eq!(extract_batch_size(&json).unwrap(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_inject_bootstrap_preserves_sglang_fields() {
|
||||||
|
let worker = create_test_worker();
|
||||||
|
let mut json = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Hello",
|
||||||
|
// SGLang extensions should be preserved
|
||||||
|
"top_k": 40,
|
||||||
|
"min_p": 0.05,
|
||||||
|
"repetition_penalty": 1.1,
|
||||||
|
"regex": "test_pattern",
|
||||||
|
"lora_path": "test.bin",
|
||||||
|
"no_stop_trim": true,
|
||||||
|
"ignore_eos": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Verify bootstrap fields added
|
||||||
|
assert!(json.get("bootstrap_host").is_some());
|
||||||
|
assert!(json.get("bootstrap_port").is_some());
|
||||||
|
assert!(json.get("bootstrap_room").is_some());
|
||||||
|
|
||||||
|
// Verify all SGLang fields preserved
|
||||||
|
assert_eq!(json["top_k"], json!(40));
|
||||||
|
assert_eq!(json["min_p"], json!(0.05));
|
||||||
|
assert_eq!(json["repetition_penalty"], json!(1.1));
|
||||||
|
assert_eq!(json["regex"], json!("test_pattern"));
|
||||||
|
assert_eq!(json["lora_path"], json!("test.bin"));
|
||||||
|
assert_eq!(json["no_stop_trim"], json!(true));
|
||||||
|
assert_eq!(json["ignore_eos"], json!(false));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_bootstrap_room_range() {
|
||||||
|
let worker = create_test_worker();
|
||||||
|
|
||||||
|
// Test single request room generation
|
||||||
|
for _ in 0..1000 {
|
||||||
|
let mut json = json!({"prompt": "test"});
|
||||||
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
|
||||||
|
let room = json["bootstrap_room"].as_u64().unwrap();
|
||||||
|
assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test batch request room generation
|
||||||
|
for _ in 0..100 {
|
||||||
|
let mut json = json!({"prompt": ["test1", "test2"]});
|
||||||
|
inject_bootstrap_fields(&mut json, &worker).unwrap();
|
||||||
|
|
||||||
|
let rooms = json["bootstrap_room"].as_array().unwrap();
|
||||||
|
for room_val in rooms {
|
||||||
|
let room = room_val.as_u64().unwrap();
|
||||||
|
assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_worker_without_bootstrap_port() {
|
||||||
|
let worker = BasicWorker::new(
|
||||||
|
"http://decode-only:8000".to_string(),
|
||||||
|
WorkerType::Decode, // No bootstrap port
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut json = json!({
|
||||||
|
"prompt": "Hello world"
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = inject_bootstrap_fields(&mut json, &worker);
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Verify bootstrap fields with null port
|
||||||
|
assert_eq!(json["bootstrap_host"], json!("decode-only"));
|
||||||
|
assert_eq!(json["bootstrap_port"], json!(null));
|
||||||
|
assert!(json["bootstrap_room"].is_number());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,10 +11,10 @@ use std::fmt::Debug;
|
|||||||
|
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
|
|
||||||
|
pub mod bootstrap_injector;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod pd_router;
|
pub mod pd_router;
|
||||||
pub mod pd_types;
|
pub mod pd_types;
|
||||||
pub mod request_adapter;
|
|
||||||
pub mod router;
|
pub mod router;
|
||||||
|
|
||||||
pub use factory::RouterFactory;
|
pub use factory::RouterFactory;
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,3 @@
|
|||||||
// Essential PDLB types extracted for PD routing
|
|
||||||
|
|
||||||
use crate::core::{Worker, WorkerType};
|
|
||||||
use crate::openai_api_types::{CompletionRequest, StringOrArray};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
// Custom error type for PD router operations
|
// Custom error type for PD router operations
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
pub enum PDRouterError {
|
pub enum PDRouterError {
|
||||||
@@ -58,428 +51,3 @@ pub enum PDSelectionPolicy {
|
|||||||
balance_rel_threshold: f32,
|
balance_rel_threshold: f32,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
// Bootstrap types from PDLB
|
|
||||||
#[derive(Debug, Deserialize, Serialize, PartialEq)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum SingleOrBatch<T> {
|
|
||||||
Single(T),
|
|
||||||
Batch(Vec<T>),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type InputIds = SingleOrBatch<Vec<i32>>;
|
|
||||||
pub type InputText = SingleOrBatch<String>;
|
|
||||||
pub type BootstrapHost = SingleOrBatch<String>;
|
|
||||||
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
|
|
||||||
pub type BootstrapRoom = SingleOrBatch<u64>;
|
|
||||||
|
|
||||||
// Bootstrap trait for request handling
|
|
||||||
pub trait Bootstrap: Send + Sync {
|
|
||||||
fn is_stream(&self) -> bool;
|
|
||||||
fn get_batch_size(&self) -> Result<Option<usize>, String>;
|
|
||||||
fn set_bootstrap_info(
|
|
||||||
&mut self,
|
|
||||||
bootstrap_host: BootstrapHost,
|
|
||||||
bootstrap_port: BootstrapPort,
|
|
||||||
bootstrap_room: BootstrapRoom,
|
|
||||||
);
|
|
||||||
|
|
||||||
fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
|
|
||||||
let batch_size = self.get_batch_size()?;
|
|
||||||
|
|
||||||
// Extract bootstrap port from prefill worker if it's a prefill type
|
|
||||||
let bootstrap_port = match prefill_worker.worker_type() {
|
|
||||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let hostname = get_hostname(prefill_worker.url());
|
|
||||||
|
|
||||||
if let Some(batch_size) = batch_size {
|
|
||||||
self.set_bootstrap_info(
|
|
||||||
BootstrapHost::Batch(vec![hostname; batch_size]),
|
|
||||||
BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
|
|
||||||
// Use high-quality random numbers to minimize collision risk
|
|
||||||
BootstrapRoom::Batch(
|
|
||||||
(0..batch_size)
|
|
||||||
.map(|_| {
|
|
||||||
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
|
|
||||||
rand::random::<u64>() & (i64::MAX as u64)
|
|
||||||
})
|
|
||||||
.collect(),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
self.set_bootstrap_info(
|
|
||||||
BootstrapHost::Single(hostname),
|
|
||||||
BootstrapPort::Single(bootstrap_port),
|
|
||||||
BootstrapRoom::Single(
|
|
||||||
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
|
|
||||||
rand::random::<u64>() & (i64::MAX as u64),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request types
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
|
||||||
pub struct GenerateReqInput {
|
|
||||||
pub text: Option<InputText>,
|
|
||||||
pub input_ids: Option<InputIds>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
pub bootstrap_host: Option<BootstrapHost>,
|
|
||||||
pub bootstrap_port: Option<BootstrapPort>,
|
|
||||||
pub bootstrap_room: Option<BootstrapRoom>,
|
|
||||||
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub other: Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenerateReqInput {
|
|
||||||
pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
|
||||||
if self.text.is_some() && self.input_ids.is_some() {
|
|
||||||
return Err("Both text and input_ids are present in the request".to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check text batch
|
|
||||||
if let Some(InputText::Batch(texts)) = &self.text {
|
|
||||||
if texts.is_empty() {
|
|
||||||
return Err("Batch text array is empty".to_string());
|
|
||||||
}
|
|
||||||
return Ok(Some(texts.len()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check input_ids batch
|
|
||||||
if let Some(InputIds::Batch(ids)) = &self.input_ids {
|
|
||||||
if ids.is_empty() {
|
|
||||||
return Err("Batch input_ids array is empty".to_string());
|
|
||||||
}
|
|
||||||
// Validate each sequence is not empty
|
|
||||||
for (i, seq) in ids.iter().enumerate() {
|
|
||||||
if seq.is_empty() {
|
|
||||||
return Err(format!("Input sequence at index {} is empty", i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Ok(Some(ids.len()));
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Bootstrap for GenerateReqInput {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
|
||||||
self.get_batch_size()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_bootstrap_info(
|
|
||||||
&mut self,
|
|
||||||
bootstrap_host: BootstrapHost,
|
|
||||||
bootstrap_port: BootstrapPort,
|
|
||||||
bootstrap_room: BootstrapRoom,
|
|
||||||
) {
|
|
||||||
self.bootstrap_host = Some(bootstrap_host);
|
|
||||||
self.bootstrap_port = Some(bootstrap_port);
|
|
||||||
self.bootstrap_room = Some(bootstrap_room);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ChatReqInput {
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
pub bootstrap_host: Option<BootstrapHost>,
|
|
||||||
pub bootstrap_port: Option<BootstrapPort>,
|
|
||||||
pub bootstrap_room: Option<BootstrapRoom>,
|
|
||||||
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub other: Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Bootstrap for ChatReqInput {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
|
||||||
// Check if 'n' parameter is present and > 1
|
|
||||||
if let Some(n_value) = self.other.get("n") {
|
|
||||||
if let Some(n) = n_value.as_u64() {
|
|
||||||
if n > 1 {
|
|
||||||
return Ok(Some(n as usize));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_bootstrap_info(
|
|
||||||
&mut self,
|
|
||||||
bootstrap_host: BootstrapHost,
|
|
||||||
bootstrap_port: BootstrapPort,
|
|
||||||
bootstrap_room: BootstrapRoom,
|
|
||||||
) {
|
|
||||||
self.bootstrap_host = Some(bootstrap_host);
|
|
||||||
self.bootstrap_port = Some(bootstrap_port);
|
|
||||||
self.bootstrap_room = Some(bootstrap_room);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bootstrap implementation for CompletionRequest to preserve OpenAI format
|
|
||||||
impl Bootstrap for CompletionRequest {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
|
||||||
if let StringOrArray::Array(prompts) = &self.prompt {
|
|
||||||
if prompts.is_empty() {
|
|
||||||
return Err("Batch prompt array is empty".to_string());
|
|
||||||
}
|
|
||||||
return Ok(Some(prompts.len()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Single string prompt
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_bootstrap_info(
|
|
||||||
&mut self,
|
|
||||||
bootstrap_host: BootstrapHost,
|
|
||||||
bootstrap_port: BootstrapPort,
|
|
||||||
bootstrap_room: BootstrapRoom,
|
|
||||||
) {
|
|
||||||
// Insert bootstrap_host - it serializes correctly whether Single or Batch
|
|
||||||
if let Ok(host_value) = serde_json::to_value(&bootstrap_host) {
|
|
||||||
self.other.insert("bootstrap_host".to_string(), host_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert bootstrap_port - it serializes correctly whether Single or Batch
|
|
||||||
if let Ok(port_value) = serde_json::to_value(&bootstrap_port) {
|
|
||||||
self.other.insert("bootstrap_port".to_string(), port_value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert bootstrap_room - it serializes correctly whether Single or Batch
|
|
||||||
if let Ok(room_value) = serde_json::to_value(&bootstrap_room) {
|
|
||||||
self.other.insert("bootstrap_room".to_string(), room_value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod bootstrap_tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::core::BasicWorker;
|
|
||||||
use crate::openai_api_types::StringOrArray;
|
|
||||||
|
|
||||||
/// Create a default CompletionRequest for testing with minimal fields set
|
|
||||||
fn default_completion_request() -> CompletionRequest {
|
|
||||||
CompletionRequest {
|
|
||||||
model: String::new(),
|
|
||||||
prompt: StringOrArray::String(String::new()),
|
|
||||||
n: None,
|
|
||||||
other: serde_json::Map::new(),
|
|
||||||
suffix: None,
|
|
||||||
max_tokens: None,
|
|
||||||
temperature: None,
|
|
||||||
top_p: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
best_of: None,
|
|
||||||
logit_bias: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
// SGLang Extensions
|
|
||||||
top_k: None,
|
|
||||||
min_p: None,
|
|
||||||
min_tokens: None,
|
|
||||||
repetition_penalty: None,
|
|
||||||
regex: None,
|
|
||||||
ebnf: None,
|
|
||||||
json_schema: None,
|
|
||||||
stop_token_ids: None,
|
|
||||||
no_stop_trim: false,
|
|
||||||
ignore_eos: false,
|
|
||||||
skip_special_tokens: true,
|
|
||||||
// SGLang Extensions
|
|
||||||
lora_path: None,
|
|
||||||
session_params: None,
|
|
||||||
return_hidden_states: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_batch_size_with_array_prompt() {
|
|
||||||
let req = CompletionRequest {
|
|
||||||
model: "test".to_string(),
|
|
||||||
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
|
||||||
..default_completion_request()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Should return batch size for array prompt
|
|
||||||
assert_eq!(req.get_batch_size().unwrap(), Some(2));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_batch_size_with_single_prompt() {
|
|
||||||
let req = CompletionRequest {
|
|
||||||
model: "test".to_string(),
|
|
||||||
prompt: StringOrArray::String("single prompt".to_string()),
|
|
||||||
..default_completion_request()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Should return None for single prompt
|
|
||||||
assert_eq!(req.get_batch_size().unwrap(), None);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_batch_size_with_n_parameter() {
|
|
||||||
let req = CompletionRequest {
|
|
||||||
model: "test".to_string(),
|
|
||||||
prompt: StringOrArray::String("single prompt".to_string()),
|
|
||||||
n: Some(3),
|
|
||||||
..default_completion_request()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Should return None for single string prompt, even with n > 1
|
|
||||||
// SGLang handles n parameter differently than batch requests
|
|
||||||
assert_eq!(req.get_batch_size().unwrap(), None);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_bootstrap_single_values() {
|
|
||||||
let mut req = CompletionRequest {
|
|
||||||
model: "test".to_string(),
|
|
||||||
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
|
||||||
..default_completion_request()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Set bootstrap info - should always use single values
|
|
||||||
req.set_bootstrap_info(
|
|
||||||
BootstrapHost::Single("test-server".to_string()),
|
|
||||||
BootstrapPort::Single(Some(5678)),
|
|
||||||
BootstrapRoom::Single(12345),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Verify single values were created
|
|
||||||
assert!(req.other.get("bootstrap_host").unwrap().is_string());
|
|
||||||
assert!(req.other.get("bootstrap_port").unwrap().is_number());
|
|
||||||
assert!(req.other.get("bootstrap_room").unwrap().is_number());
|
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
req.other.get("bootstrap_host").unwrap().as_str().unwrap(),
|
|
||||||
"test-server"
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
req.other.get("bootstrap_port").unwrap().as_u64().unwrap(),
|
|
||||||
5678
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
req.other.get("bootstrap_room").unwrap().as_u64().unwrap(),
|
|
||||||
12345
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_bootstrap_array_values() {
|
|
||||||
let mut req = CompletionRequest {
|
|
||||||
model: "test".to_string(),
|
|
||||||
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
|
||||||
..default_completion_request()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Set bootstrap info with arrays
|
|
||||||
req.set_bootstrap_info(
|
|
||||||
BootstrapHost::Batch(vec!["test-server".to_string(); 2]),
|
|
||||||
BootstrapPort::Batch(vec![Some(5678); 2]),
|
|
||||||
BootstrapRoom::Batch(vec![12345, 67890]),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Verify arrays were created correctly
|
|
||||||
assert!(req.other.get("bootstrap_host").unwrap().is_array());
|
|
||||||
assert!(req.other.get("bootstrap_port").unwrap().is_array());
|
|
||||||
assert!(req.other.get("bootstrap_room").unwrap().is_array());
|
|
||||||
|
|
||||||
let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap();
|
|
||||||
assert_eq!(hosts.len(), 2);
|
|
||||||
assert_eq!(hosts[0].as_str().unwrap(), "test-server");
|
|
||||||
|
|
||||||
let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap();
|
|
||||||
assert_eq!(ports.len(), 2);
|
|
||||||
assert_eq!(ports[0].as_u64().unwrap(), 5678);
|
|
||||||
|
|
||||||
let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap();
|
|
||||||
assert_eq!(rooms.len(), 2);
|
|
||||||
assert_eq!(rooms[0].as_u64().unwrap(), 12345);
|
|
||||||
assert_eq!(rooms[1].as_u64().unwrap(), 67890);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_bootstrap_room_range() {
|
|
||||||
// Test that bootstrap_room values are within the expected range [0, 2^63 - 1]
|
|
||||||
let worker = BasicWorker::new(
|
|
||||||
"http://test:8000".to_string(),
|
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: Some(8080),
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
// Test single request
|
|
||||||
let mut single_req = GenerateReqInput {
|
|
||||||
text: Some(InputText::Single("test".to_string())),
|
|
||||||
input_ids: None,
|
|
||||||
stream: false,
|
|
||||||
bootstrap_host: None,
|
|
||||||
bootstrap_port: None,
|
|
||||||
bootstrap_room: None,
|
|
||||||
other: Value::Object(serde_json::Map::new()),
|
|
||||||
};
|
|
||||||
|
|
||||||
for _ in 0..200000 {
|
|
||||||
single_req.add_bootstrap_info(&worker).unwrap();
|
|
||||||
if let Some(BootstrapRoom::Single(room)) = single_req.bootstrap_room {
|
|
||||||
// Verify the room value is within signed 64-bit range
|
|
||||||
assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
|
|
||||||
} else {
|
|
||||||
panic!("Expected single bootstrap room");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test batch request
|
|
||||||
let mut batch_req = GenerateReqInput {
|
|
||||||
text: Some(InputText::Batch(vec![
|
|
||||||
"test1".to_string(),
|
|
||||||
"test2".to_string(),
|
|
||||||
])),
|
|
||||||
input_ids: None,
|
|
||||||
stream: false,
|
|
||||||
bootstrap_host: None,
|
|
||||||
bootstrap_port: None,
|
|
||||||
bootstrap_room: None,
|
|
||||||
other: Value::Object(serde_json::Map::new()),
|
|
||||||
};
|
|
||||||
|
|
||||||
for _ in 0..200000 {
|
|
||||||
batch_req.add_bootstrap_info(&worker).unwrap();
|
|
||||||
if let Some(BootstrapRoom::Batch(rooms)) = &batch_req.bootstrap_room {
|
|
||||||
for room in rooms {
|
|
||||||
// Verify each room value is within signed 64-bit range
|
|
||||||
assert!(*room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
panic!("Expected batch bootstrap rooms");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,15 @@
|
|||||||
// Integration test to ensure benchmarks compile and basic functionality works
|
// Integration test to ensure benchmarks compile and basic functionality works
|
||||||
// This prevents benchmarks from breaking in CI
|
// This prevents benchmarks from breaking in CI
|
||||||
|
//
|
||||||
|
// UPDATED: Removed deprecated ToPdRequest usage, now uses direct JSON serialization
|
||||||
|
|
||||||
use serde_json::{from_str, to_string};
|
use serde_json::{from_str, to_string, to_value};
|
||||||
|
use sglang_router_rs::core::{BasicWorker, WorkerType};
|
||||||
use sglang_router_rs::openai_api_types::{
|
use sglang_router_rs::openai_api_types::{
|
||||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||||
SamplingParams, StringOrArray, UserMessageContent,
|
SamplingParams, StringOrArray, UserMessageContent,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
|
use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields;
|
||||||
|
|
||||||
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
||||||
fn default_generate_request() -> GenerateRequest {
|
fn default_generate_request() -> GenerateRequest {
|
||||||
@@ -114,6 +117,15 @@ fn default_completion_request() -> CompletionRequest {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn create_test_worker() -> BasicWorker {
|
||||||
|
BasicWorker::new(
|
||||||
|
"http://test-server:8000".to_string(),
|
||||||
|
WorkerType::Prefill {
|
||||||
|
bootstrap_port: Some(5678),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_benchmark_request_creation() {
|
fn test_benchmark_request_creation() {
|
||||||
// Ensure all benchmark request types can be created without panicking
|
// Ensure all benchmark request types can be created without panicking
|
||||||
@@ -197,8 +209,8 @@ fn test_benchmark_serialization_roundtrip() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_benchmark_request_adaptation() {
|
fn test_benchmark_bootstrap_injection() {
|
||||||
// Test that PD request adaptation works for benchmark types
|
// Test that bootstrap injection works for benchmark types (replaces PD request adaptation)
|
||||||
|
|
||||||
let generate_req = GenerateRequest {
|
let generate_req = GenerateRequest {
|
||||||
text: Some("Test prompt".to_string()),
|
text: Some("Test prompt".to_string()),
|
||||||
@@ -236,24 +248,40 @@ fn test_benchmark_request_adaptation() {
|
|||||||
..default_completion_request()
|
..default_completion_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test PD adaptation (should not panic)
|
let worker = create_test_worker();
|
||||||
let _pd_generate = generate_req.to_pd_request();
|
|
||||||
let _pd_chat = chat_req.to_pd_request();
|
// Test bootstrap injection (should not panic)
|
||||||
let _pd_completion = completion_req.to_pd_request();
|
let mut generate_json = to_value(&generate_req).unwrap();
|
||||||
|
let mut chat_json = to_value(&chat_req).unwrap();
|
||||||
|
let mut completion_json = to_value(&completion_req).unwrap();
|
||||||
|
|
||||||
|
assert!(inject_bootstrap_fields(&mut generate_json, &worker).is_ok());
|
||||||
|
assert!(inject_bootstrap_fields(&mut chat_json, &worker).is_ok());
|
||||||
|
assert!(inject_bootstrap_fields(&mut completion_json, &worker).is_ok());
|
||||||
|
|
||||||
|
// Verify bootstrap fields were added
|
||||||
|
assert!(generate_json.get("bootstrap_host").is_some());
|
||||||
|
assert!(generate_json.get("bootstrap_port").is_some());
|
||||||
|
assert!(generate_json.get("bootstrap_room").is_some());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_benchmark_regular_routing() {
|
fn test_benchmark_direct_json_routing() {
|
||||||
// Test regular routing functionality for benchmark types
|
// Test direct JSON routing functionality for benchmark types (replaces regular routing)
|
||||||
|
|
||||||
let generate_req = GenerateRequest {
|
let generate_req = GenerateRequest {
|
||||||
text: Some("Test prompt".to_string()),
|
text: Some("Test prompt".to_string()),
|
||||||
..default_generate_request()
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test regular routing methods (should not panic)
|
// Test direct JSON conversion (replaces regular routing methods)
|
||||||
let _json = generate_req.to_json();
|
let json = to_value(&generate_req).unwrap();
|
||||||
let _bytes = generate_req.to_bytes();
|
let json_string = to_string(&json).unwrap();
|
||||||
|
let bytes = json_string.as_bytes();
|
||||||
|
|
||||||
|
// Verify conversions work
|
||||||
|
assert!(!json_string.is_empty());
|
||||||
|
assert!(!bytes.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -266,23 +294,36 @@ fn test_benchmark_performance_baseline() {
|
|||||||
..default_generate_request()
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Serialization should be fast (< 1ms for simple requests)
|
// Test the actual simplified pipeline: to_value + bootstrap injection
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let _json = to_string(&generate_req).unwrap();
|
let worker = create_test_worker();
|
||||||
let serialize_duration = start.elapsed();
|
|
||||||
|
// This mirrors the actual router pipeline
|
||||||
|
let mut json = to_value(&generate_req).unwrap();
|
||||||
|
let _ = inject_bootstrap_fields(&mut json, &worker);
|
||||||
|
|
||||||
|
let total_duration = start.elapsed();
|
||||||
assert!(
|
assert!(
|
||||||
serialize_duration.as_millis() < 1,
|
total_duration.as_millis() < 5,
|
||||||
"Serialization took too long: {:?}",
|
"Simplified pipeline took too long: {:?} (should be faster than old adapter approach)",
|
||||||
serialize_duration
|
total_duration
|
||||||
);
|
);
|
||||||
|
|
||||||
// PD adaptation should be very fast (< 1ms)
|
// Individual components should also be fast
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let _pd_req = generate_req.to_pd_request();
|
let _json = to_value(&generate_req).unwrap();
|
||||||
let adapt_duration = start.elapsed();
|
let to_value_duration = start.elapsed();
|
||||||
|
|
||||||
|
let start = Instant::now();
|
||||||
|
let mut json = to_value(&generate_req).unwrap();
|
||||||
|
let _ = inject_bootstrap_fields(&mut json, &worker);
|
||||||
|
let inject_duration = start.elapsed();
|
||||||
|
|
||||||
|
// Bootstrap injection should be faster than the JSON conversion
|
||||||
assert!(
|
assert!(
|
||||||
adapt_duration.as_millis() < 1,
|
inject_duration <= to_value_duration * 3,
|
||||||
"PD adaptation took too long: {:?}",
|
"Bootstrap injection ({:?}) should not be much slower than JSON conversion ({:?})",
|
||||||
adapt_duration
|
inject_duration,
|
||||||
|
to_value_duration
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user