// Source: sglang/sgl-router/src/routers/http/pd_router.rs
// (viewer metadata — 1393 lines, 47 KiB, Rust — retained as a comment so the
// file remains valid Rust)
use std::{sync::Arc, time::Instant};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use futures_util::StreamExt;
use reqwest::Client;
use serde::Serialize;
use serde_json::{json, Value};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn};
use super::pd_types::api_path;
use crate::{
config::types::RetryConfig,
core::{
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
},
metrics::RouterMetrics,
policies::{LoadBalancingPolicy, PolicyRegistry},
protocols::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
common::{InputIds, StringOrArray},
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::{header_utils, RouterTrait},
};
/// Prefill/decode (PD) disaggregated router.
///
/// Every request is dispatched to a prefill worker and a decode worker in
/// parallel; the decode response is what the client ultimately receives.
#[derive(Debug)]
pub struct PDRouter {
/// Registry of all known workers (prefill and decode).
pub worker_registry: Arc<WorkerRegistry>,
/// Load-balancing policies for prefill and decode selection.
pub policy_registry: Arc<PolicyRegistry>,
/// Shared HTTP client used for all upstream requests.
pub client: Client,
/// Retry/backoff settings applied to the dual dispatch.
pub retry_config: RetryConfig,
/// Optional API key (not referenced elsewhere in this file's visible code).
pub api_key: Option<String>,
/// When true, worker lookup is scoped by model id (inference-gateway mode).
pub enable_igw: bool,
}
/// Per-request settings threaded through the dual-dispatch pipeline.
#[derive(Clone)]
struct PDRequestContext<'a> {
/// Upstream route the request is posted to (e.g. "/generate").
route: &'static str,
/// Number of batch elements, when the request is batched.
batch_size: Option<usize>,
/// Whether the client requested a streaming (SSE) response.
is_stream: bool,
/// Whether input token logprobs must be merged from the prefill response.
return_logprob: bool,
/// Raw request text, captured only when a policy needs it.
request_text: Option<String>,
/// Model id used for worker lookup in IGW mode.
model_id: Option<&'a str>,
}
impl PDRouter {
/// Forward a GET to the given endpoint of the first registered prefill
/// worker, or answer 503 when no prefill workers exist.
async fn proxy_to_first_prefill_worker(
    &self,
    endpoint: &str,
    headers: Option<Vec<(String, String)>>,
) -> Response {
    let target = self
        .worker_registry
        .get_prefill_workers()
        .first()
        .map(|w| w.url().to_string());
    match target {
        Some(worker_url) => self.proxy_to_worker(worker_url, endpoint, headers).await,
        None => (
            StatusCode::SERVICE_UNAVAILABLE,
            "No prefill servers available".to_string(),
        )
            .into_response(),
    }
}
/// Forward a GET request to `{worker_url}/{endpoint}`, copying the supplied
/// request headers and preserving the upstream response headers on success.
///
/// Non-2xx upstream statuses are passed through with a descriptive body;
/// transport failures become 500 responses.
async fn proxy_to_worker(
    &self,
    worker_url: String,
    endpoint: &str,
    headers: Option<Vec<(String, String)>>,
) -> Response {
    let url = format!("{}/{}", worker_url, endpoint);
    let mut request_builder = self.client.get(&url);
    if let Some(headers) = headers {
        for (name, value) in headers {
            request_builder = request_builder.header(name, value);
        }
    }
    match request_builder.send().await {
        Ok(res) if res.status().is_success() => {
            let response_headers = header_utils::preserve_response_headers(res.headers());
            match res.bytes().await {
                Ok(body) => {
                    let mut response = Response::new(Body::from(body));
                    *response.status_mut() = StatusCode::OK;
                    *response.headers_mut() = response_headers;
                    response
                }
                Err(e) => {
                    error!("Failed to read response body: {}", e);
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        format!("Failed to read response body: {}", e),
                    )
                        .into_response()
                }
            }
        }
        Ok(res) => {
            let status = StatusCode::from_u16(res.status().as_u16())
                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
            // BUG FIX: the message previously interpolated the status before
            // the label and ended with a dangling ": "
            // (e.g. "503 Service Unavailable server returned status: ").
            (status, format!("Server returned status: {}", res.status())).into_response()
        }
        Err(e) => {
            error!("Failed to proxy request server: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to proxy request: {}", e),
            )
                .into_response()
        }
    }
}
/// Build a PD router from the shared application context.
///
/// # Errors
/// Currently infallible; returns `Result` to keep parity with other router
/// constructors.
pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
Ok(PDRouter {
worker_registry: Arc::clone(&ctx.worker_registry),
policy_registry: Arc::clone(&ctx.policy_registry),
client: ctx.client.clone(),
retry_config: ctx.router_config.effective_retry_config(),
api_key: ctx.router_config.api_key.clone(),
enable_igw: ctx.router_config.enable_igw,
})
}
/// Map a PD pair selection failure onto a 503 response and record the metric.
fn handle_server_selection_error(error: String) -> Response {
    error!("Failed to select PD pair error={}", error);
    RouterMetrics::record_pd_error("server_selection");
    let body = format!("No available servers: {}", error);
    (StatusCode::SERVICE_UNAVAILABLE, body).into_response()
}
/// Map a request serialization failure onto a 500 response.
fn handle_serialization_error(error: impl std::fmt::Display) -> Response {
    error!("Failed to serialize request error={}", error);
    let reply = (
        StatusCode::INTERNAL_SERVER_ERROR,
        "Failed to serialize request",
    );
    reply.into_response()
}
/// Batch size of a generate request, if it is batched.
fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
    // GenerateRequest doesn't support batch via arrays, only via input_ids
    match &req.input_ids {
        Some(InputIds::Batch(batches)) if !batches.is_empty() => Some(batches.len()),
        _ => None,
    }
}
/// Batch size of a chat request: `n > 1` means multiple parallel completions.
fn get_chat_batch_size(req: &ChatCompletionRequest) -> Option<usize> {
    req.n.filter(|&n| n > 1).map(|n| n as usize)
}
/// Batch size of a completion request, if the prompt is a non-empty array.
fn get_completion_batch_size(req: &CompletionRequest) -> Option<usize> {
    match &req.prompt {
        StringOrArray::Array(arr) if !arr.is_empty() => Some(arr.len()),
        _ => None,
    }
}
/// Attach bootstrap metadata (host, port, room id) to the request payload so
/// the decode server can locate the prefill server's KV cache.
///
/// For batched requests each field becomes an array with one entry per batch
/// element; otherwise scalar fields are written.
fn inject_bootstrap_into_value(
    mut original: Value,
    prefill_worker: &dyn Worker,
    batch_size: Option<usize>,
) -> Result<Value, String> {
    let obj = original
        .as_object_mut()
        .ok_or_else(|| "Request must be a JSON object".to_string())?;
    match batch_size {
        Some(n) => {
            let hosts: Vec<Value> = (0..n)
                .map(|_| Value::from(prefill_worker.bootstrap_host()))
                .collect();
            let ports: Vec<Value> = (0..n)
                .map(|_| match prefill_worker.bootstrap_port() {
                    Some(v) => Value::from(v),
                    None => Value::Null,
                })
                .collect();
            // Each batch element gets its own freshly generated room id.
            let rooms: Vec<Value> = (0..n)
                .map(|_| Value::from(super::pd_types::generate_room_id()))
                .collect();
            obj.insert("bootstrap_host".to_string(), Value::Array(hosts));
            obj.insert("bootstrap_port".to_string(), Value::Array(ports));
            obj.insert("bootstrap_room".to_string(), Value::Array(rooms));
        }
        None => {
            obj.insert(
                "bootstrap_host".to_string(),
                Value::from(prefill_worker.bootstrap_host()),
            );
            obj.insert(
                "bootstrap_port".to_string(),
                match prefill_worker.bootstrap_port() {
                    Some(v) => Value::from(v),
                    None => Value::Null,
                },
            );
            obj.insert(
                "bootstrap_room".to_string(),
                Value::from(super::pd_types::generate_room_id()),
            );
        }
    }
    Ok(original)
}
/// Dispatch a request to a (prefill, decode) worker pair, retrying the whole
/// selection + dispatch according to `retry_config`.
///
/// Each attempt re-selects a PD pair, re-serializes the original request,
/// injects fresh bootstrap metadata, and records the outcome on both workers.
async fn execute_dual_dispatch<T: Serialize + Clone>(
&self,
headers: Option<&HeaderMap>,
original_request: &T,
context: PDRequestContext<'_>,
) -> Response {
let start_time = Instant::now();
let route = context.route;
RetryExecutor::execute_response_with_retry(
&self.retry_config,
{
// Clone into the retry closure; each attempt gets its own copy.
let original_request = original_request.clone();
move |attempt: u32| {
let original_request = original_request.clone();
let context = context.clone();
async move {
// Re-select a fresh pair per attempt so a failing worker
// is not retried blindly.
let (prefill, decode) = match self
.select_pd_pair(context.request_text.as_deref(), context.model_id)
.await
{
Ok(pair) => pair,
Err(e) => {
RouterMetrics::record_pd_error("server_selection");
return Self::handle_server_selection_error(e);
}
};
debug!(
"PD retry attempt {} using prefill={} decode={}",
attempt,
prefill.url(),
decode.url()
);
let mut json_request = match serde_json::to_value(&original_request) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
// Bootstrap fields tell decode where to pull KV cache from.
json_request = match Self::inject_bootstrap_into_value(
json_request,
prefill.as_ref(),
context.batch_size,
) {
Ok(v) => v,
Err(e) => return Self::handle_serialization_error(e),
};
let response = self
.execute_dual_dispatch_internal(
headers,
json_request,
context,
prefill.as_ref(),
decode.as_ref(),
start_time,
)
.await;
// 4xx counts as a worker success for outcome tracking:
// the worker responded; only 5xx/transport errors are held
// against it.
let _status = response.status();
let not_error = _status.is_success() || _status.is_client_error();
prefill.record_outcome(not_error);
decode.record_outcome(not_error);
response
}
}
},
|res, _attempt| is_retryable_status(res.status()),
|delay, attempt| {
RouterMetrics::record_retry(route);
RouterMetrics::record_retry_backoff_duration(delay, attempt);
},
|| RouterMetrics::record_retries_exhausted(route),
)
.await
}
/// Translate a decode-server error into a client response.
///
/// Streaming requests get the error wrapped in an SSE `data:` frame so the
/// client's event stream terminates cleanly; non-streaming requests get the
/// upstream status and body passed through.
async fn handle_decode_error_response(
    &self,
    res: reqwest::Response,
    context: &PDRequestContext<'_>,
    prefill: &dyn Worker,
    decode: &dyn Worker,
) -> Response {
    let status = res.status();
    if context.is_stream {
        // Handle streaming error response. Headers must be captured before
        // `res.bytes()` consumes the response.
        let response_headers = header_utils::preserve_response_headers(res.headers());
        let error_payload = match res.bytes().await {
            Ok(error_body) => {
                if let Ok(error_json) = serde_json::from_slice::<Value>(&error_body) {
                    json!({ "message": error_json, "status": status.as_u16() })
                } else {
                    json!({ "message": String::from_utf8_lossy(&error_body).to_string(), "status": status.as_u16() })
                }
            }
            Err(e) => {
                json!({ "message": format!("Decode server error: {}", e), "status": status.as_u16() })
            }
        };
        // BUG FIX: this frame previously used single quotes
        // ("data: {'error': ...}"), which is not valid JSON and breaks SSE
        // clients that parse the payload.
        let sse_data = format!(
            "data: {{\"error\": {}}}",
            serde_json::to_string(&error_payload).unwrap_or_default()
        );
        let error_stream = tokio_stream::once(Ok(axum::body::Bytes::from(sse_data)));
        let decode_url = decode.url().to_string();
        self.create_streaming_response(
            error_stream,
            status,
            None,
            context.return_logprob,
            Some(decode_url),
            Some(response_headers),
            prefill,
            decode,
        )
    } else {
        // Handle non-streaming error response
        match res.bytes().await {
            Ok(error_body) => (status, error_body).into_response(),
            Err(e) => (status, format!("Decode server error: {}", e)).into_response(),
        }
    }
}
// Internal method that performs the actual dual dispatch (without retry logic)
//
// Posts the same JSON payload to both workers concurrently. The decode
// response drives the client reply; the prefill response is checked (and,
// for logprob requests, its body retained for merging) because decode cannot
// proceed without the prefill leg having succeeded.
async fn execute_dual_dispatch_internal(
&self,
headers: Option<&HeaderMap>,
json_request: Value,
context: PDRequestContext<'_>,
prefill: &dyn Worker,
decode: &dyn Worker,
start_time: Instant,
) -> Response {
// For non-streaming: use guard for automatic load management
// For streaming: load will be managed in create_streaming_response
let _guard = if !context.is_stream {
Some(WorkerLoadGuard::new_multi(vec![prefill, decode]))
} else {
None
};
// Build both requests
let prefill_request = self.build_post_with_headers(
&self.client,
prefill.url(),
context.route,
&json_request,
headers,
false,
);
let decode_request = self.build_post_with_headers(
&self.client,
decode.url(),
context.route,
&json_request,
headers,
false,
);
// Send both requests concurrently and wait for both
debug!(
"Sending concurrent requests to prefill={} decode={}",
prefill.url(),
decode.url()
);
let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send());
debug!("Received responses from both servers");
let duration = start_time.elapsed();
RouterMetrics::record_pd_request_duration(context.route, duration);
RouterMetrics::record_pd_request(context.route);
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
// Process decode response
match decode_result {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
debug!("Decode response status: {}", status);
if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url());
error!(
"Decode server returned error status decode_url={} status={}",
decode.url(),
status
);
return self
.handle_decode_error_response(res, &context, prefill, decode)
.await;
}
// Process prefill response
let prefill_body = if context.return_logprob {
match self
.process_prefill_response(
prefill_result,
prefill.url(),
context.return_logprob,
)
.await
{
Ok((_, body)) => body,
Err(error_response) => return error_response,
}
} else {
// Even if we don't need logprobs, we should check prefill status
match self
.process_prefill_response(prefill_result, prefill.url(), false)
.await
{
Ok((_, body)) => body,
Err(error_response) => return error_response,
}
};
if context.is_stream {
// Streaming response: extract input logprobs from the prefill
// body up front so they can be merged into each SSE chunk.
let prefill_logprobs = if context.return_logprob {
prefill_body
.as_ref()
.and_then(|body| serde_json::from_slice::<Value>(body).ok())
.and_then(|json| {
json.pointer("/meta_info/input_token_logprobs").cloned()
})
} else {
None
};
let response_headers = header_utils::preserve_response_headers(res.headers());
self.create_streaming_response(
res.bytes_stream(),
status,
prefill_logprobs,
context.return_logprob,
None,
Some(response_headers),
prefill,
decode,
)
} else {
// Non-streaming response
if context.return_logprob {
self.process_non_streaming_response(
res,
status,
context.return_logprob,
prefill_body,
)
.await
} else {
// Direct passthrough when no logprobs needed
let response_headers =
header_utils::preserve_response_headers(res.headers());
match res.bytes().await {
Ok(decode_body) => {
let mut response = Response::new(Body::from(decode_body));
*response.status_mut() = status;
*response.headers_mut() = response_headers;
response
}
Err(e) => {
error!("Failed to read decode response: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
.into_response()
}
}
}
}
}
Err(e) => {
error!(
decode_url = %decode.url(),
error = %e,
"Decode request failed"
);
RouterMetrics::record_pd_decode_error(decode.url());
(
StatusCode::BAD_GATEWAY,
format!("Decode server error: {}", e),
)
.into_response()
}
}
}
/// True when either the prefill or the decode policy wants the raw request
/// text (e.g. cache-aware routing).
fn policies_need_request_text(&self) -> bool {
    if self.policy_registry.get_prefill_policy().needs_request_text() {
        return true;
    }
    self.policy_registry.get_decode_policy().needs_request_text()
}
/// Pick one prefill and one decode worker via the configured policies.
///
/// In IGW mode the candidate set is restricted to workers serving
/// `model_id`; otherwise all registered prefill/decode workers are eligible.
async fn select_pd_pair(
    &self,
    request_text: Option<&str>,
    model_id: Option<&str>,
) -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), String> {
    // Model-scoped lookup only applies in inference-gateway mode.
    let effective_model_id = if self.enable_igw { model_id } else { None };
    debug!(
        "Selecting PD pair: enable_igw={}, model_id={:?}, effective_model_id={:?}",
        self.enable_igw, model_id, effective_model_id
    );
    let prefill_workers: Vec<Arc<dyn Worker>> = match effective_model_id {
        Some(model) => self
            .worker_registry
            .get_by_model_fast(model)
            .into_iter()
            .filter(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }))
            .collect(),
        None => self.worker_registry.get_prefill_workers(),
    };
    let decode_workers: Vec<Arc<dyn Worker>> = match effective_model_id {
        Some(model) => self
            .worker_registry
            .get_by_model_fast(model)
            .into_iter()
            .filter(|w| matches!(w.worker_type(), WorkerType::Decode))
            .collect(),
        None => self.worker_registry.get_decode_workers(),
    };
    let prefill_policy = self.policy_registry.get_prefill_policy();
    let decode_policy = self.policy_registry.get_decode_policy();
    let prefill = Self::pick_worker_by_policy_arc(
        &prefill_workers,
        &*prefill_policy,
        request_text,
        "prefill",
    )?;
    let decode = Self::pick_worker_by_policy_arc(
        &decode_workers,
        &*decode_policy,
        request_text,
        "decode",
    )?;
    Ok((prefill, decode))
}
/// Select one worker from `workers` via `policy`, considering only workers
/// that are currently available. `worker_type` is used in error messages.
fn pick_worker_by_policy_arc(
    workers: &[Arc<dyn Worker>],
    policy: &dyn LoadBalancingPolicy,
    request_text: Option<&str>,
    worker_type: &str,
) -> Result<Arc<dyn Worker>, String> {
    if workers.is_empty() {
        return Err(format!(
            "No {} workers available. Please check if {} servers are configured and healthy.",
            worker_type, worker_type
        ));
    }
    let candidates: Vec<Arc<dyn Worker>> = workers
        .iter()
        .filter(|w| w.is_available())
        .cloned()
        .collect();
    if candidates.is_empty() {
        return Err(format!(
            "No available {} workers (all circuits open or unhealthy)",
            worker_type
        ));
    }
    match policy.select_worker(&candidates, request_text) {
        Some(idx) => Ok(candidates[idx].clone()),
        None => Err(format!(
            "Policy {} failed to select a {} worker",
            policy.name(),
            worker_type
        )),
    }
}
/// Wrap an upstream byte stream in an SSE response, tracking worker load for
/// the lifetime of the stream.
///
/// Load is incremented here and decremented by the spawned forwarder task
/// when the stream ends (`data: [DONE]`, client disconnect, or error). The
/// decrement goes through the registry by URL so a worker removed mid-stream
/// is handled gracefully.
#[allow(clippy::too_many_arguments)]
fn create_streaming_response(
&self,
stream: impl futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
status: StatusCode,
prefill_logprobs: Option<Value>,
return_logprob: bool,
decode_url: Option<String>,
headers: Option<HeaderMap>,
prefill: &dyn Worker,
decode: &dyn Worker,
) -> Response {
prefill.increment_load();
decode.increment_load();
let prefill_url = prefill.url().to_string();
let decode_url_str = decode.url().to_string();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let registry = self.worker_registry.clone();
tokio::spawn(async move {
let mut stream_completed = false;
futures_util::pin_mut!(stream);
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Scan the raw bytes for the 12-byte SSE terminator "data: [DONE]".
let is_done = chunk
.as_ref()
.windows(12)
.any(|window| window == b"data: [DONE]");
let result = if return_logprob && prefill_logprobs.is_some() {
Self::merge_streaming_logprobs(prefill_logprobs.clone(), &chunk)
.unwrap_or(chunk)
} else {
chunk
};
// send() only fails when the client side hung up.
if tx.send(Ok(result)).is_err() {
break;
}
if is_done {
stream_completed = true;
break;
}
}
Err(e) => {
if let Some(ref url) = decode_url {
error!("Stream error from decode server {}: {}", url, e);
RouterMetrics::record_pd_stream_error(url);
}
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
if let Some(worker) = registry.get_by_url(&prefill_url) {
worker.decrement_load();
debug!(
"Decremented load for prefill worker: {} (stream_completed: {})",
prefill_url, stream_completed
);
}
if let Some(worker) = registry.get_by_url(&decode_url_str) {
worker.decrement_load();
debug!(
"Decremented load for decode worker: {} (stream_completed: {})",
decode_url_str, stream_completed
);
}
});
let stream = UnboundedReceiverStream::new(rx);
let body = Body::from_stream(stream);
let mut response = Response::new(body);
*response.status_mut() = status;
let mut headers = headers.unwrap_or_default();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
*response.headers_mut() = headers;
response
}
// Helper to process non-streaming decode response with logprob merging.
// Without logprob merging (or if either body fails to parse) the decode
// payload is passed through untouched.
async fn process_non_streaming_response(
    &self,
    res: reqwest::Response,
    status: StatusCode,
    return_logprob: bool,
    prefill_body: Option<bytes::Bytes>,
) -> Response {
    let decode_body = match res.bytes().await {
        Ok(body) => body,
        Err(e) => {
            error!("Failed to read decode response: {}", e);
            return (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
                .into_response();
        }
    };
    // Merging only happens when it was requested AND a prefill body exists.
    let prefill_body = match (return_logprob, prefill_body) {
        (true, Some(body)) => body,
        _ => return (status, decode_body).into_response(),
    };
    let parsed = (
        serde_json::from_slice::<Value>(&prefill_body),
        serde_json::from_slice::<Value>(&decode_body),
    );
    let (prefill_json, mut decode_json) = match parsed {
        (Ok(p), Ok(d)) => (p, d),
        _ => {
            warn!("Failed to parse responses for logprob merging");
            return (status, decode_body).into_response();
        }
    };
    Self::merge_logprobs_in_json(&prefill_json, &mut decode_json);
    // Return merged response; on serialization failure fall back to the
    // unmodified decode body.
    match serde_json::to_vec(&decode_json) {
        Ok(body) => (status, body).into_response(),
        Err(e) => {
            error!("Failed to serialize merged response: {}", e);
            (status, decode_body).into_response()
        }
    }
}
// Helper to process prefill response and extract body if needed for logprobs
//
// Returns the prefill status plus (optionally) its body bytes, or a ready
// error `Response` when the prefill leg failed — prefill failure is fatal in
// disaggregated mode because decode depends on its KV cache.
async fn process_prefill_response(
&self,
prefill_result: Result<reqwest::Response, reqwest::Error>,
prefill_url: &str,
return_logprob: bool,
) -> Result<(StatusCode, Option<bytes::Bytes>), Response> {
// Check prefill result first - it's critical for disaggregated mode
let prefill_response = match prefill_result {
Ok(response) => response,
Err(e) => {
RouterMetrics::record_pd_prefill_error(prefill_url);
error!(
"Prefill server failed (CRITICAL) prefill_url={} error={}. Decode will timeout without prefill KV cache.",
prefill_url,
e
);
// Return error immediately - don't wait for decode to timeout
return Err((
StatusCode::BAD_GATEWAY,
format!(
"Prefill server error: {}. This will cause decode timeout.",
e
),
)
.into_response());
}
};
let prefill_status = StatusCode::from_u16(prefill_response.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
// Check if prefill succeeded
if !prefill_status.is_success() {
RouterMetrics::record_pd_prefill_error(prefill_url);
// Get error body from prefill
let error_msg = prefill_response
.text()
.await
.unwrap_or_else(|_| "Unknown prefill error".to_string());
error!(
"Prefill server returned error status prefill_url={} status={} body={}",
prefill_url, prefill_status, error_msg
);
return Err((
prefill_status,
format!("Prefill server error ({}): {}", prefill_status, error_msg),
)
.into_response());
}
// Read prefill body if needed for logprob merging
let prefill_body = if return_logprob {
match prefill_response.bytes().await {
Ok(body) => Some(body),
Err(e) => {
// Best effort: logprob merging is skipped but the request proceeds.
warn!("Failed to read prefill response body for logprobs: {}", e);
None
}
}
} else {
// For non-logprob requests, just consume the response without storing
debug!("Consuming prefill response body (non-logprob request)");
match prefill_response.bytes().await {
Ok(_) => debug!("Prefill response consumed successfully"),
Err(e) => warn!("Error consuming prefill response: {}", e),
}
None
};
Ok((prefill_status, prefill_body))
}
/// Build a POST for `route` on `url`, forwarding only whitelisted
/// end-to-end headers from the incoming request.
fn build_post_with_headers(
    &self,
    client: &Client,
    url: &str,
    route: &str,
    json_request: &Value,
    headers: Option<&HeaderMap>,
    connection_close: bool,
) -> reqwest::RequestBuilder {
    let mut request = client.post(api_path(url, route)).json(json_request);
    if connection_close {
        request = request.header("Connection", "close");
    }
    let Some(headers) = headers else {
        return request;
    };
    for (name, value) in headers.iter() {
        let name_lc = name.as_str().to_ascii_lowercase();
        // Whitelist important end-to-end headers, skip hop-by-hop
        let forward = name_lc.starts_with("x-request-id-")
            || matches!(
                name_lc.as_str(),
                "authorization" | "x-request-id" | "x-correlation-id"
            );
        if !forward {
            continue;
        }
        if let Ok(val) = value.to_str() {
            request = request.header(name, val);
        }
    }
    request
}
// Helper to merge logprobs from prefill and decode responses.
// Prepends the prefill `meta_info.input_token_logprobs` array to the decode
// one in place; returns true only if the merge actually happened.
fn merge_logprobs_in_json(prefill_json: &Value, decode_json: &mut Value) -> bool {
    let prefill_arr = match prefill_json
        .get("meta_info")
        .and_then(|meta| meta.get("input_token_logprobs"))
        .and_then(Value::as_array)
    {
        Some(arr) => arr.clone(),
        None => return false,
    };
    let Some(decode_arr) = decode_json
        .get_mut("meta_info")
        .and_then(|meta| meta.get_mut("input_token_logprobs"))
        .and_then(Value::as_array_mut)
    else {
        return false;
    };
    // Input logprobs are ordered prefill-first, then decode.
    let mut merged = prefill_arr;
    merged.extend(decode_arr.iter().cloned());
    *decode_arr = merged;
    true
}
// Simple helper to merge logprobs in streaming responses.
// Only well-formed `data: <json>` chunks are rewritten; keep-alives, [DONE]
// markers, non-UTF8 and unparsable chunks are reported as Err(()) so the
// caller forwards the original bytes unchanged.
fn merge_streaming_logprobs(
    prefill_logprobs: Option<Value>,
    decode_chunk: &[u8],
) -> Result<bytes::Bytes, ()> {
    let chunk_str = std::str::from_utf8(decode_chunk).map_err(|_| ())?;
    if !chunk_str.starts_with("data: ") || chunk_str.contains("[DONE]") {
        return Err(());
    }
    let payload = chunk_str.trim_start_matches("data: ").trim();
    let mut decode_json: Value = serde_json::from_str(payload).map_err(|_| ())?;
    // Prepend the prefill input logprobs when both sides have an array.
    if let Some(p_arr) = prefill_logprobs.as_ref().and_then(Value::as_array) {
        if let Some(d_logprobs) = decode_json
            .get_mut("meta_info")
            .and_then(|meta| meta.get_mut("input_token_logprobs"))
        {
            if let Some(d_arr) = d_logprobs.as_array() {
                let mut merged = p_arr.clone();
                merged.extend(d_arr.clone());
                *d_logprobs = Value::Array(merged);
            }
        }
    }
    // Re-serialize as a complete SSE frame.
    let merged_str = format!(
        "data: {}\n\n",
        serde_json::to_string(&decode_json).unwrap_or_default()
    );
    Ok(bytes::Bytes::from(merged_str))
}
}
#[async_trait]
impl RouterTrait for PDRouter {
/// Expose the concrete router type for downcasting through the trait object.
fn as_any(&self) -> &dyn std::any::Any {
self
}
/// Probe `/health_generate` on one policy-selected prefill/decode pair,
/// returning 200 only when both servers pass.
async fn health_generate(&self, _req: Request<Body>) -> Response {
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
// Select a random worker pair using the policy
let (prefill, decode) = match self.select_pd_pair(None, None).await {
Ok(pair) => pair,
Err(e) => {
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("No healthy worker pair available: {}", e),
)
.into_response();
}
};
// Hit both servers concurrently.
let prefill_url = format!("{}/health_generate", prefill.url());
let (prefill_result, decode_result) = tokio::join!(
self.client.get(&prefill_url).send(),
self.client
.get(format!("{}/health_generate", decode.url()))
.send()
);
// Check results
let mut errors = Vec::new();
match prefill_result {
Ok(res) if res.status().is_success() => {
debug!(
"Health generate passed for prefill server: {}",
prefill.url()
);
}
Ok(res) => {
errors.push(format!(
"Prefill {} returned status {}",
prefill.url(),
res.status()
));
}
Err(e) => {
errors.push(format!("Prefill {} error: {}", prefill.url(), e));
}
}
match decode_result {
Ok(res) if res.status().is_success() => {
debug!("Health generate passed for decode server: {}", decode.url());
}
Ok(res) => {
errors.push(format!(
"Decode {} returned status {}",
decode.url(),
res.status()
));
}
Err(e) => {
errors.push(format!("Decode {} error: {}", decode.url(), e));
}
}
if errors.is_empty() {
(
StatusCode::OK,
format!(
"Health generate passed on selected pair: prefill={}, decode={}",
prefill.url(),
decode.url()
),
)
.into_response()
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
format!("Health generate failed: {:?}", errors),
)
.into_response()
}
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
// NOTE(review): the previous comment here claimed the info came from the
// first *decode* server, but the code proxies to the first *prefill*
// worker — confirm which server is intended to back /get_server_info.
self.proxy_to_first_prefill_worker("get_server_info", None)
.await
}
/// Proxy `GET /v1/models` to the first prefill worker.
async fn get_models(&self, req: Request<Body>) -> Response {
    // Copy the headers out of the request up front to avoid Send issues
    // with holding the request across the proxy await.
    let headers = header_utils::copy_request_headers(&req);
    self.proxy_to_first_prefill_worker("v1/models", Some(headers))
        .await
}
/// Proxy `GET /get_model_info` to the first prefill worker.
async fn get_model_info(&self, req: Request<Body>) -> Response {
    // Copy the headers out of the request up front to avoid Send issues
    // with holding the request across the proxy await.
    let headers = header_utils::copy_request_headers(&req);
    self.proxy_to_first_prefill_worker("get_model_info", Some(headers))
        .await
}
/// Route a /generate request through the PD dual-dispatch pipeline.
async fn route_generate(
    &self,
    headers: Option<&HeaderMap>,
    body: &GenerateRequest,
    model_id: Option<&str>,
) -> Response {
    // Only materialize the request text when a policy actually needs it.
    let request_text = self
        .policies_need_request_text()
        .then(|| body.text.as_deref().map(|s| s.to_string()))
        .flatten();
    let context = PDRequestContext {
        route: "/generate",
        batch_size: Self::get_generate_batch_size(body),
        is_stream: body.stream,
        return_logprob: body.return_logprob.unwrap_or(false),
        request_text,
        model_id,
    };
    self.execute_dual_dispatch(headers, body, context).await
}
/// Route a /v1/chat/completions request through the PD dual-dispatch pipeline.
async fn route_chat(
    &self,
    headers: Option<&HeaderMap>,
    body: &ChatCompletionRequest,
    model_id: Option<&str>,
) -> Response {
    // The first user/system message's text is enough for text-aware policies.
    let request_text = if self.policies_need_request_text() {
        body.messages.first().and_then(|msg| match msg {
            ChatMessage::User { content, .. } => match content {
                UserMessageContent::Text(text) => Some(text.clone()),
                UserMessageContent::Parts(_) => None,
            },
            ChatMessage::System { content, .. } => Some(content.clone()),
            _ => None,
        })
    } else {
        None
    };
    let context = PDRequestContext {
        route: "/v1/chat/completions",
        batch_size: Self::get_chat_batch_size(body),
        is_stream: body.stream,
        return_logprob: body.logprobs,
        request_text,
        model_id,
    };
    self.execute_dual_dispatch(headers, body, context).await
}
/// Route a /v1/completions request through the PD dual-dispatch pipeline.
async fn route_completion(
    &self,
    headers: Option<&HeaderMap>,
    body: &CompletionRequest,
    model_id: Option<&str>,
) -> Response {
    // For array prompts, only the first element feeds text-aware policies.
    let request_text = if self.policies_need_request_text() {
        match &body.prompt {
            StringOrArray::String(s) => Some(s.clone()),
            StringOrArray::Array(v) => v.first().map(|s| s.to_string()),
        }
    } else {
        None
    };
    let context = PDRequestContext {
        route: "/v1/completions",
        batch_size: Self::get_completion_batch_size(body),
        is_stream: body.stream,
        return_logprob: body.logprobs.is_some(),
        request_text,
        model_id,
    };
    self.execute_dual_dispatch(headers, body, context).await
}
/// The Responses API is not supported by the PD router yet.
async fn route_responses(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &ResponsesRequest,
    _model_id: Option<&str>,
) -> Response {
    let reply = (
        StatusCode::NOT_IMPLEMENTED,
        "Responses endpoint not implemented for PD router",
    );
    reply.into_response()
}
/// Retrieving a stored response is not supported by the PD router yet.
async fn get_response(
    &self,
    _headers: Option<&HeaderMap>,
    _response_id: &str,
    _params: &ResponsesGetParams,
) -> Response {
    let reply = (
        StatusCode::NOT_IMPLEMENTED,
        "Responses retrieve endpoint not implemented for PD router",
    );
    reply.into_response()
}
/// Cancelling a stored response is not supported by the PD router yet.
async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
    let reply = (
        StatusCode::NOT_IMPLEMENTED,
        "Responses cancel endpoint not implemented for PD router",
    );
    reply.into_response()
}
/// Embeddings are not supported by the PD router yet.
async fn route_embeddings(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &EmbeddingRequest,
    _model_id: Option<&str>,
) -> Response {
    let reply = (
        StatusCode::NOT_IMPLEMENTED,
        "Embeddings endpoint not implemented for PD router",
    );
    reply.into_response()
}
/// Route a /v1/rerank request through the PD dual-dispatch pipeline.
/// Rerank is always non-streaming and never returns logprobs.
async fn route_rerank(
    &self,
    headers: Option<&HeaderMap>,
    body: &RerankRequest,
    model_id: Option<&str>,
) -> Response {
    // Extract text for cache-aware routing
    let req_text = self
        .policies_need_request_text()
        .then(|| body.query.clone());
    let context = PDRequestContext {
        route: "/v1/rerank",
        batch_size: None,
        is_stream: false,
        return_logprob: false,
        request_text: req_text,
        model_id,
    };
    self.execute_dual_dispatch(headers, body, context).await
}
/// Identifier for this router implementation (used by callers of the trait).
fn router_type(&self) -> &'static str {
"pd"
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{BasicWorkerBuilder, WorkerType};
// Router with empty registries, round-robin policy, and IGW disabled.
fn create_test_pd_router() -> PDRouter {
let worker_registry = Arc::new(WorkerRegistry::new());
let policy_registry =
Arc::new(PolicyRegistry::new(crate::config::PolicyConfig::RoundRobin));
PDRouter {
worker_registry,
policy_registry,
client: Client::new(),
retry_config: RetryConfig::default(),
api_key: Some("test_api_key".to_string()),
enable_igw: false,
}
}
// Basic worker with the given type and initial health flag.
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
let worker = BasicWorkerBuilder::new(url)
.worker_type(worker_type)
.build();
worker.set_healthy(healthy);
Box::new(worker)
}
// Selection must skip unhealthy prefill workers and pick the healthy one.
#[tokio::test]
async fn test_select_healthy_prefill_worker() {
let router = create_test_pd_router();
let healthy_worker = create_test_worker(
"http://healthy".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
true,
);
let unhealthy_worker = create_test_worker(
"http://unhealthy".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
false,
);
let decode_worker =
create_test_worker("http://decode".to_string(), WorkerType::Decode, true);
router.worker_registry.register(Arc::from(unhealthy_worker));
router.worker_registry.register(Arc::from(healthy_worker));
router.worker_registry.register(Arc::from(decode_worker));
let result = router.select_pd_pair(None, None).await;
assert!(result.is_ok());
let (prefill, _decode) = result.unwrap();
assert_eq!(prefill.url(), "http://healthy");
assert!(prefill.is_healthy());
}
// With no registered workers, pair selection must fail with the
// "No prefill workers available" message.
#[tokio::test]
async fn test_empty_worker_lists() {
let router = create_test_pd_router();
let result = router.select_pd_pair(None, None).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("No prefill workers available"));
}
// WorkerLoadGuard must increment load on construction and decrement on drop.
#[test]
fn test_worker_load_metrics() {
let prefill_worker = create_test_worker(
"http://prefill".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
true,
);
let decode_worker =
create_test_worker("http://decode".to_string(), WorkerType::Decode, true);
let _guard =
WorkerLoadGuard::new_multi(vec![prefill_worker.as_ref(), decode_worker.as_ref()]);
assert_eq!(prefill_worker.load(), 1);
assert_eq!(decode_worker.load(), 1);
drop(_guard);
assert_eq!(prefill_worker.load(), 0);
assert_eq!(decode_worker.load(), 0);
}
// Streaming responses must hold load for the lifetime of the stream and
// release it (via the registry) once the stream ends.
#[tokio::test]
async fn test_streaming_load_tracking() {
use futures_util::StreamExt;
use tokio::time::{sleep, Duration};
let router = create_test_pd_router();
let prefill_worker = create_test_worker(
"http://prefill".to_string(),
WorkerType::Prefill {
bootstrap_port: None,
},
true,
);
let decode_worker =
create_test_worker("http://decode".to_string(), WorkerType::Decode, true);
router.worker_registry.register(Arc::from(prefill_worker));
router.worker_registry.register(Arc::from(decode_worker));
let prefill_workers = router.worker_registry.get_prefill_workers();
let decode_workers = router.worker_registry.get_decode_workers();
let prefill_ref = prefill_workers[0].clone();
let decode_ref = decode_workers[0].clone();
assert_eq!(prefill_ref.load(), 0);
assert_eq!(decode_ref.load(), 0);
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let stream = UnboundedReceiverStream::new(rx);
let _response = router.create_streaming_response(
stream.map(Ok),
StatusCode::OK,
None,
false,
None,
None,
prefill_ref.as_ref(),
decode_ref.as_ref(),
);
// Load is held while the stream is open...
assert_eq!(prefill_ref.load(), 1);
assert_eq!(decode_ref.load(), 1);
tx.send(bytes::Bytes::from("test data")).unwrap();
sleep(Duration::from_millis(10)).await;
assert_eq!(prefill_ref.load(), 1);
assert_eq!(decode_ref.load(), 1);
// ...and released shortly after the sender side closes.
drop(tx);
sleep(Duration::from_millis(100)).await;
assert_eq!(prefill_ref.load(), 0);
assert_eq!(decode_ref.load(), 0);
}
}