[pd-router] add retry and circuit breakfor for pd router (#9051)
This commit is contained in:
@@ -16,7 +16,7 @@ pub use circuit_breaker::{
|
|||||||
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
|
CircuitBreaker, CircuitBreakerConfig, CircuitBreakerStats, CircuitState,
|
||||||
};
|
};
|
||||||
pub use error::{WorkerError, WorkerResult};
|
pub use error::{WorkerError, WorkerResult};
|
||||||
pub use retry::{BackoffCalculator, RetryError, RetryExecutor};
|
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
|
||||||
pub use worker::{
|
pub use worker::{
|
||||||
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection,
|
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, Worker, WorkerCollection,
|
||||||
WorkerFactory, WorkerLoadGuard, WorkerType,
|
WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||||
|
|||||||
@@ -1,9 +1,23 @@
|
|||||||
use crate::config::types::RetryConfig;
|
use crate::config::types::RetryConfig;
|
||||||
|
use axum::http::StatusCode;
|
||||||
use axum::response::Response;
|
use axum::response::Response;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
|
/// Check if an HTTP status code indicates a retryable error
|
||||||
|
pub fn is_retryable_status(status: StatusCode) -> bool {
|
||||||
|
matches!(
|
||||||
|
status,
|
||||||
|
StatusCode::REQUEST_TIMEOUT
|
||||||
|
| StatusCode::TOO_MANY_REQUESTS
|
||||||
|
| StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
| StatusCode::BAD_GATEWAY
|
||||||
|
| StatusCode::SERVICE_UNAVAILABLE
|
||||||
|
| StatusCode::GATEWAY_TIMEOUT
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Computes exponential backoff with optional jitter.
|
/// Computes exponential backoff with optional jitter.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct BackoffCalculator;
|
pub struct BackoffCalculator;
|
||||||
@@ -21,8 +35,8 @@ impl BackoffCalculator {
|
|||||||
// Apply jitter in range [-j, +j]
|
// Apply jitter in range [-j, +j]
|
||||||
let jitter = config.jitter_factor.max(0.0).min(1.0);
|
let jitter = config.jitter_factor.max(0.0).min(1.0);
|
||||||
if jitter > 0.0 {
|
if jitter > 0.0 {
|
||||||
let mut rng = rand::thread_rng();
|
let mut rng = rand::rng();
|
||||||
let jitter_scale: f32 = rng.gen_range(-jitter..=jitter);
|
let jitter_scale: f32 = rng.random_range(-jitter..=jitter);
|
||||||
let jitter_ms = (delay_ms as f32 * jitter_scale)
|
let jitter_ms = (delay_ms as f32 * jitter_scale)
|
||||||
.round()
|
.round()
|
||||||
.max(-(delay_ms as f32));
|
.max(-(delay_ms as f32));
|
||||||
|
|||||||
@@ -2,7 +2,10 @@
|
|||||||
// This module handles routing for disaggregated prefill-decode systems
|
// This module handles routing for disaggregated prefill-decode systems
|
||||||
use super::pd_types::{api_path, PDRouterError};
|
use super::pd_types::{api_path, PDRouterError};
|
||||||
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
|
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
|
||||||
use crate::core::{CircuitBreakerConfig, HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
use crate::core::{
|
||||||
|
is_retryable_status, CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory,
|
||||||
|
WorkerLoadGuard,
|
||||||
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
@@ -17,6 +20,7 @@ use axum::{
|
|||||||
};
|
};
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
|
use serde::Serialize;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
@@ -43,6 +47,16 @@ pub struct PDRouter {
|
|||||||
_decode_health_checker: Option<HealthChecker>,
|
_decode_health_checker: Option<HealthChecker>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Request context for PD router operations
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct PDRequestContext {
|
||||||
|
route: &'static str,
|
||||||
|
batch_size: Option<usize>,
|
||||||
|
is_stream: bool,
|
||||||
|
return_logprob: bool,
|
||||||
|
request_text: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
impl PDRouter {
|
impl PDRouter {
|
||||||
// Dynamic worker management methods for service discovery
|
// Dynamic worker management methods for service discovery
|
||||||
|
|
||||||
@@ -218,12 +232,8 @@ impl PDRouter {
|
|||||||
let core_cb_config = CircuitBreakerConfig {
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
failure_threshold: circuit_breaker_config.failure_threshold,
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
success_threshold: circuit_breaker_config.success_threshold,
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
timeout_duration: std::time::Duration::from_secs(
|
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||||
circuit_breaker_config.timeout_duration_secs,
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
),
|
|
||||||
window_duration: std::time::Duration::from_secs(
|
|
||||||
circuit_breaker_config.window_duration_secs,
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Convert URLs to Worker trait objects
|
// Convert URLs to Worker trait objects
|
||||||
@@ -459,8 +469,96 @@ impl PDRouter {
|
|||||||
Ok(original)
|
Ok(original)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute the dual dispatch to prefill and decode servers
|
// Execute the dual dispatch to prefill and decode servers with retries and bootstrap injection
|
||||||
async fn execute_dual_dispatch(
|
async fn execute_dual_dispatch<T: Serialize + Clone>(
|
||||||
|
&self,
|
||||||
|
headers: Option<&HeaderMap>,
|
||||||
|
original_request: &T,
|
||||||
|
context: PDRequestContext,
|
||||||
|
) -> Response {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
|
||||||
|
let route = context.route;
|
||||||
|
RetryExecutor::execute_response_with_retry(
|
||||||
|
&self.retry_config,
|
||||||
|
// Operation per attempt
|
||||||
|
{
|
||||||
|
let original_request = original_request.clone();
|
||||||
|
move |attempt: u32| {
|
||||||
|
let original_request = original_request.clone();
|
||||||
|
let context = context.clone();
|
||||||
|
async move {
|
||||||
|
// Select workers fresh for each attempt
|
||||||
|
let (prefill, decode) =
|
||||||
|
match self.select_pd_pair(context.request_text.as_deref()).await {
|
||||||
|
Ok(pair) => pair,
|
||||||
|
Err(e) => {
|
||||||
|
RouterMetrics::record_pd_error("server_selection");
|
||||||
|
return Self::handle_server_selection_error(e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"PD retry attempt {} using prefill={} decode={}",
|
||||||
|
attempt,
|
||||||
|
prefill.url(),
|
||||||
|
decode.url()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Serialize the original request
|
||||||
|
let mut json_request = match serde_json::to_value(&original_request) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => return Self::handle_serialization_error(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Inject bootstrap based on current prefill worker
|
||||||
|
json_request = match Self::inject_bootstrap_into_value(
|
||||||
|
json_request,
|
||||||
|
prefill.as_ref(),
|
||||||
|
context.batch_size,
|
||||||
|
) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => return Self::handle_serialization_error(e),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Execute the actual dual dispatch
|
||||||
|
let response = self
|
||||||
|
.execute_dual_dispatch_internal(
|
||||||
|
headers,
|
||||||
|
json_request,
|
||||||
|
context.route,
|
||||||
|
prefill.as_ref(),
|
||||||
|
decode.as_ref(),
|
||||||
|
context.is_stream,
|
||||||
|
context.return_logprob,
|
||||||
|
start_time,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Record outcomes for circuit breakers
|
||||||
|
let is_success = response.status().is_success();
|
||||||
|
prefill.record_outcome(is_success);
|
||||||
|
decode.record_outcome(is_success);
|
||||||
|
|
||||||
|
response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// Should retry predicate
|
||||||
|
|res, _attempt| is_retryable_status(res.status()),
|
||||||
|
// On backoff hook
|
||||||
|
|delay, attempt| {
|
||||||
|
RouterMetrics::record_retry(route);
|
||||||
|
RouterMetrics::record_retry_backoff_duration(delay, attempt);
|
||||||
|
},
|
||||||
|
// On exhausted hook
|
||||||
|
|| RouterMetrics::record_retries_exhausted(route),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
// Internal method that performs the actual dual dispatch (without retry logic)
|
||||||
|
async fn execute_dual_dispatch_internal(
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
json_request: Value,
|
json_request: Value,
|
||||||
@@ -696,7 +794,7 @@ impl PDRouter {
|
|||||||
self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text()
|
self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Select a pair of prefill and decode servers
|
// Select a pair of prefill and decode servers considering circuit breaker state
|
||||||
async fn select_pd_pair(
|
async fn select_pd_pair(
|
||||||
&self,
|
&self,
|
||||||
request_text: Option<&str>,
|
request_text: Option<&str>,
|
||||||
@@ -711,31 +809,60 @@ impl PDRouter {
|
|||||||
.read()
|
.read()
|
||||||
.map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?;
|
.map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?;
|
||||||
|
|
||||||
// Check we have workers
|
// Select workers using helper function
|
||||||
if prefill_workers.is_empty() {
|
let prefill = Self::pick_worker_by_policy(
|
||||||
return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string());
|
&*prefill_workers,
|
||||||
}
|
&*self.prefill_policy,
|
||||||
if decode_workers.is_empty() {
|
request_text,
|
||||||
return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string());
|
"prefill",
|
||||||
}
|
)?;
|
||||||
|
|
||||||
// Select prefill worker using prefill policy
|
let decode = Self::pick_worker_by_policy(
|
||||||
let prefill_idx = self
|
&*decode_workers,
|
||||||
.prefill_policy
|
&*self.decode_policy,
|
||||||
.select_worker(&prefill_workers, request_text)
|
request_text,
|
||||||
.ok_or("Failed to select prefill worker")?;
|
"decode",
|
||||||
|
)?;
|
||||||
|
|
||||||
// Select decode worker using decode policy
|
|
||||||
let decode_idx = self
|
|
||||||
.decode_policy
|
|
||||||
.select_worker(&decode_workers, request_text)
|
|
||||||
.ok_or("Failed to select decode worker")?;
|
|
||||||
|
|
||||||
let prefill = prefill_workers[prefill_idx].clone_worker();
|
|
||||||
let decode = decode_workers[decode_idx].clone_worker();
|
|
||||||
Ok((prefill, decode))
|
Ok((prefill, decode))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to select a worker using the policy
|
||||||
|
fn pick_worker_by_policy(
|
||||||
|
workers: &[Box<dyn Worker>],
|
||||||
|
policy: &dyn LoadBalancingPolicy,
|
||||||
|
request_text: Option<&str>,
|
||||||
|
worker_type: &str,
|
||||||
|
) -> Result<Box<dyn Worker>, String> {
|
||||||
|
// Check if we have any workers
|
||||||
|
if workers.is_empty() {
|
||||||
|
return Err(format!(
|
||||||
|
"No {} workers available. Please check if {} servers are configured and healthy.",
|
||||||
|
worker_type, worker_type
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter available workers (healthy + circuit breaker not open)
|
||||||
|
let available_workers: Vec<Box<dyn Worker>> = workers
|
||||||
|
.iter()
|
||||||
|
.filter(|w| w.is_available())
|
||||||
|
.map(|w| w.clone_worker())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if available_workers.is_empty() {
|
||||||
|
return Err(format!(
|
||||||
|
"No available {} workers (all circuits open or unhealthy)",
|
||||||
|
worker_type
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Let policy select from available workers only
|
||||||
|
match policy.select_worker(&available_workers, request_text) {
|
||||||
|
Some(idx) => Ok(available_workers[idx].clone_worker()),
|
||||||
|
None => Err(format!("Policy could not select a {} worker", worker_type)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Background task to monitor worker loads with shared client
|
// Background task to monitor worker loads with shared client
|
||||||
async fn monitor_worker_loads_with_client(
|
async fn monitor_worker_loads_with_client(
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
@@ -1449,61 +1576,41 @@ impl RouterTrait for PDRouter {
|
|||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &GenerateRequest,
|
body: &GenerateRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let start = Instant::now();
|
// Extract parameters
|
||||||
|
|
||||||
// Extract flags for routing logic
|
|
||||||
let is_stream = body.stream;
|
let is_stream = body.stream;
|
||||||
let return_logprob = body.return_logprob;
|
let return_logprob = body.return_logprob;
|
||||||
|
|
||||||
// Extract text for cache-aware routing only if needed
|
// Extract text for cache-aware routing
|
||||||
let request_text = if self.policies_need_request_text() {
|
let request_text = if self.policies_need_request_text() {
|
||||||
body.text.as_deref().or_else(|| {
|
body.text
|
||||||
body.prompt.as_ref().and_then(|p| match p {
|
.as_deref()
|
||||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
.or_else(|| {
|
||||||
crate::openai_api_types::StringOrArray::Array(v) => {
|
body.prompt.as_ref().and_then(|p| match p {
|
||||||
v.first().map(|s| s.as_str())
|
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
||||||
}
|
crate::openai_api_types::StringOrArray::Array(v) => {
|
||||||
|
v.first().map(|s| s.as_str())
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
.map(|s| s.to_string())
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Select servers
|
// Calculate batch size
|
||||||
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
|
||||||
Ok(pair) => pair,
|
|
||||||
Err(e) => return Self::handle_server_selection_error(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Log routing decision
|
|
||||||
info!(
|
|
||||||
"PD routing decision route=/generate prefill_url={} decode_url={}",
|
|
||||||
prefill.url(),
|
|
||||||
decode.url()
|
|
||||||
);
|
|
||||||
|
|
||||||
let batch_size = Self::get_generate_batch_size(body);
|
let batch_size = Self::get_generate_batch_size(body);
|
||||||
let original = match serde_json::to_value(body) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(e) => return Self::handle_serialization_error(e),
|
|
||||||
};
|
|
||||||
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(e) => return Self::handle_serialization_error(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Execute dual dispatch
|
// Create context
|
||||||
self.execute_dual_dispatch(
|
let context = PDRequestContext {
|
||||||
headers,
|
route: "/generate",
|
||||||
json,
|
batch_size,
|
||||||
"/generate",
|
|
||||||
prefill.as_ref(),
|
|
||||||
decode.as_ref(),
|
|
||||||
is_stream,
|
is_stream,
|
||||||
return_logprob,
|
return_logprob,
|
||||||
start,
|
request_text,
|
||||||
)
|
};
|
||||||
.await
|
|
||||||
|
// Execute with retry and bootstrap injection
|
||||||
|
self.execute_dual_dispatch(headers, body, context).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_chat(
|
async fn route_chat(
|
||||||
@@ -1511,25 +1618,19 @@ impl RouterTrait for PDRouter {
|
|||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &ChatCompletionRequest,
|
body: &ChatCompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let start = Instant::now();
|
// Extract parameters
|
||||||
|
|
||||||
// Extract flags for routing logic
|
|
||||||
let is_stream = body.stream;
|
let is_stream = body.stream;
|
||||||
let return_logprob = body.logprobs;
|
let return_logprob = body.logprobs;
|
||||||
|
|
||||||
// Extract text for cache-aware routing from chat messages only if needed
|
// Extract text for cache-aware routing
|
||||||
let request_text = if self.policies_need_request_text() {
|
let request_text = if self.policies_need_request_text() {
|
||||||
body.messages.first().and_then(|msg| match msg {
|
body.messages.first().and_then(|msg| match msg {
|
||||||
crate::openai_api_types::ChatMessage::User { content, .. } => {
|
crate::openai_api_types::ChatMessage::User { content, .. } => match content {
|
||||||
match content {
|
crate::openai_api_types::UserMessageContent::Text(text) => Some(text.clone()),
|
||||||
crate::openai_api_types::UserMessageContent::Text(text) => {
|
crate::openai_api_types::UserMessageContent::Parts(_) => None,
|
||||||
Some(text.as_str())
|
},
|
||||||
}
|
|
||||||
crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content
|
|
||||||
}
|
|
||||||
}
|
|
||||||
crate::openai_api_types::ChatMessage::System { content, .. } => {
|
crate::openai_api_types::ChatMessage::System { content, .. } => {
|
||||||
Some(content.as_str())
|
Some(content.clone())
|
||||||
}
|
}
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
@@ -1537,41 +1638,20 @@ impl RouterTrait for PDRouter {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Select servers
|
// Calculate batch size
|
||||||
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
|
||||||
Ok(pair) => pair,
|
|
||||||
Err(e) => return Self::handle_server_selection_error(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Log routing decision
|
|
||||||
info!(
|
|
||||||
"PD routing decision route=/v1/chat/completions prefill_url={} decode_url={}",
|
|
||||||
prefill.url(),
|
|
||||||
decode.url()
|
|
||||||
);
|
|
||||||
|
|
||||||
let batch_size = Self::get_chat_batch_size(body);
|
let batch_size = Self::get_chat_batch_size(body);
|
||||||
let original = match serde_json::to_value(body) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(e) => return Self::handle_serialization_error(e),
|
|
||||||
};
|
|
||||||
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(e) => return Self::handle_serialization_error(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Execute dual dispatch
|
// Create context
|
||||||
self.execute_dual_dispatch(
|
let context = PDRequestContext {
|
||||||
headers,
|
route: "/v1/chat/completions",
|
||||||
json,
|
batch_size,
|
||||||
"/v1/chat/completions",
|
|
||||||
prefill.as_ref(),
|
|
||||||
decode.as_ref(),
|
|
||||||
is_stream,
|
is_stream,
|
||||||
return_logprob,
|
return_logprob,
|
||||||
start,
|
request_text,
|
||||||
)
|
};
|
||||||
.await
|
|
||||||
|
// Execute with retry and bootstrap injection
|
||||||
|
self.execute_dual_dispatch(headers, body, context).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_completion(
|
async fn route_completion(
|
||||||
@@ -1579,57 +1659,36 @@ impl RouterTrait for PDRouter {
|
|||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &CompletionRequest,
|
body: &CompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
let start = Instant::now();
|
// Extract parameters
|
||||||
|
|
||||||
// Extract flags for routing logic
|
|
||||||
let is_stream = body.stream;
|
let is_stream = body.stream;
|
||||||
let return_logprob = body.logprobs.is_some();
|
let return_logprob = body.logprobs.is_some();
|
||||||
|
|
||||||
// Extract text for cache-aware routing only if needed
|
// Extract text for cache-aware routing
|
||||||
let request_text = if self.policies_need_request_text() {
|
let request_text = if self.policies_need_request_text() {
|
||||||
match &body.prompt {
|
match &body.prompt {
|
||||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
crate::openai_api_types::StringOrArray::String(s) => Some(s.clone()),
|
||||||
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
|
crate::openai_api_types::StringOrArray::Array(v) => {
|
||||||
|
v.first().map(|s| s.to_string())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Select servers
|
// Calculate batch size
|
||||||
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
|
||||||
Ok(pair) => pair,
|
|
||||||
Err(e) => return Self::handle_server_selection_error(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Log routing decision
|
|
||||||
info!(
|
|
||||||
"PD routing decision route=/v1/completions prefill_url={} decode_url={}",
|
|
||||||
prefill.url(),
|
|
||||||
decode.url()
|
|
||||||
);
|
|
||||||
|
|
||||||
let batch_size = Self::get_completion_batch_size(body);
|
let batch_size = Self::get_completion_batch_size(body);
|
||||||
let original = match serde_json::to_value(body) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(e) => return Self::handle_serialization_error(e),
|
|
||||||
};
|
|
||||||
let json = match Self::inject_bootstrap_into_value(original, prefill.as_ref(), batch_size) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(e) => return Self::handle_serialization_error(e),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Execute dual dispatch
|
// Create context
|
||||||
self.execute_dual_dispatch(
|
let context = PDRequestContext {
|
||||||
headers,
|
route: "/v1/completions",
|
||||||
json,
|
batch_size,
|
||||||
"/v1/completions",
|
|
||||||
prefill.as_ref(),
|
|
||||||
decode.as_ref(),
|
|
||||||
is_stream,
|
is_stream,
|
||||||
return_logprob,
|
return_logprob,
|
||||||
start,
|
request_text,
|
||||||
)
|
};
|
||||||
.await
|
|
||||||
|
// Execute with retry and bootstrap injection
|
||||||
|
self.execute_dual_dispatch(headers, body, context).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
async fn flush_cache(&self) -> Response {
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
|
use crate::config::types::{CircuitBreakerConfig as ConfigCircuitBreakerConfig, RetryConfig};
|
||||||
use crate::core::{CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory};
|
use crate::core::{
|
||||||
|
is_retryable_status, CircuitBreakerConfig, HealthChecker, RetryExecutor, Worker, WorkerFactory,
|
||||||
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
@@ -81,12 +83,8 @@ impl Router {
|
|||||||
let core_cb_config = CircuitBreakerConfig {
|
let core_cb_config = CircuitBreakerConfig {
|
||||||
failure_threshold: circuit_breaker_config.failure_threshold,
|
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||||
success_threshold: circuit_breaker_config.success_threshold,
|
success_threshold: circuit_breaker_config.success_threshold,
|
||||||
timeout_duration: std::time::Duration::from_secs(
|
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||||
circuit_breaker_config.timeout_duration_secs,
|
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||||
),
|
|
||||||
window_duration: std::time::Duration::from_secs(
|
|
||||||
circuit_breaker_config.window_duration_secs,
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create Worker trait objects from URLs
|
// Create Worker trait objects from URLs
|
||||||
@@ -397,18 +395,6 @@ impl Router {
|
|||||||
Some(available[idx].clone_worker())
|
Some(available[idx].clone_worker())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_retryable_status(status: StatusCode) -> bool {
|
|
||||||
matches!(
|
|
||||||
status,
|
|
||||||
StatusCode::REQUEST_TIMEOUT
|
|
||||||
| StatusCode::TOO_MANY_REQUESTS
|
|
||||||
| StatusCode::INTERNAL_SERVER_ERROR
|
|
||||||
| StatusCode::BAD_GATEWAY
|
|
||||||
| StatusCode::SERVICE_UNAVAILABLE
|
|
||||||
| StatusCode::GATEWAY_TIMEOUT
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn route_typed_request<
|
pub async fn route_typed_request<
|
||||||
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
||||||
>(
|
>(
|
||||||
@@ -461,7 +447,7 @@ impl Router {
|
|||||||
response
|
response
|
||||||
},
|
},
|
||||||
// should_retry predicate
|
// should_retry predicate
|
||||||
|res, _attempt| Self::is_retryable_status(res.status()),
|
|res, _attempt| is_retryable_status(res.status()),
|
||||||
// on_backoff hook
|
// on_backoff hook
|
||||||
|delay, attempt| {
|
|delay, attempt| {
|
||||||
RouterMetrics::record_retry(route);
|
RouterMetrics::record_retry(route);
|
||||||
@@ -476,7 +462,7 @@ impl Router {
|
|||||||
let duration = start.elapsed();
|
let duration = start.elapsed();
|
||||||
RouterMetrics::record_request(route);
|
RouterMetrics::record_request(route);
|
||||||
RouterMetrics::record_generate_duration(duration);
|
RouterMetrics::record_generate_duration(duration);
|
||||||
} else if !Self::is_retryable_status(response.status()) {
|
} else if !is_retryable_status(response.status()) {
|
||||||
RouterMetrics::record_request_error(route, "non_retryable_error");
|
RouterMetrics::record_request_error(route, "non_retryable_error");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user