[router] improve router logs and request id header (#8415)
This commit is contained in:
@@ -93,6 +93,19 @@ python -m sglang_router.launch_router \
|
|||||||
--prometheus-port 9000
|
--prometheus-port 9000
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Request ID Tracking
|
||||||
|
|
||||||
|
Track requests across distributed systems with configurable headers:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Use custom request ID headers
|
||||||
|
python -m sglang_router.launch_router \
|
||||||
|
--worker-urls http://localhost:8080 \
|
||||||
|
--request-id-headers x-trace-id x-request-id
|
||||||
|
```
|
||||||
|
|
||||||
|
Default headers: `x-request-id`, `x-correlation-id`, `x-trace-id`, `request-id`
|
||||||
|
|
||||||
## Advanced Features
|
## Advanced Features
|
||||||
|
|
||||||
### Kubernetes Service Discovery
|
### Kubernetes Service Discovery
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class RouterArgs:
|
|||||||
# Prometheus configuration
|
# Prometheus configuration
|
||||||
prometheus_port: Optional[int] = None
|
prometheus_port: Optional[int] = None
|
||||||
prometheus_host: Optional[str] = None
|
prometheus_host: Optional[str] = None
|
||||||
|
# Request ID headers configuration
|
||||||
|
request_id_headers: Optional[List[str]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(
|
def add_cli_args(
|
||||||
@@ -255,6 +257,12 @@ class RouterArgs:
|
|||||||
default="127.0.0.1",
|
default="127.0.0.1",
|
||||||
help="Host address to bind the Prometheus metrics server",
|
help="Host address to bind the Prometheus metrics server",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}request-id-headers",
|
||||||
|
type=str,
|
||||||
|
nargs="*",
|
||||||
|
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(
|
def from_cli_args(
|
||||||
@@ -313,6 +321,7 @@ class RouterArgs:
|
|||||||
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
|
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
|
||||||
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
||||||
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
||||||
|
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -481,6 +490,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
if router_args.decode_policy
|
if router_args.decode_policy
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
|
request_id_headers=router_args.request_id_headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
router.start()
|
router.start()
|
||||||
|
|||||||
@@ -54,6 +54,9 @@ class Router:
|
|||||||
If not specified, uses the main policy. Default: None
|
If not specified, uses the main policy. Default: None
|
||||||
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
|
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
|
||||||
If not specified, uses the main policy. Default: None
|
If not specified, uses the main policy. Default: None
|
||||||
|
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
|
||||||
|
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
|
||||||
|
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -85,6 +88,7 @@ class Router:
|
|||||||
decode_urls: Optional[List[str]] = None,
|
decode_urls: Optional[List[str]] = None,
|
||||||
prefill_policy: Optional[PolicyType] = None,
|
prefill_policy: Optional[PolicyType] = None,
|
||||||
decode_policy: Optional[PolicyType] = None,
|
decode_policy: Optional[PolicyType] = None,
|
||||||
|
request_id_headers: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
if selector is None:
|
if selector is None:
|
||||||
selector = {}
|
selector = {}
|
||||||
@@ -121,6 +125,7 @@ class Router:
|
|||||||
decode_urls=decode_urls,
|
decode_urls=decode_urls,
|
||||||
prefill_policy=prefill_policy,
|
prefill_policy=prefill_policy,
|
||||||
decode_policy=decode_policy,
|
decode_policy=decode_policy,
|
||||||
|
request_id_headers=request_id_headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
|
|||||||
@@ -29,6 +29,8 @@ pub struct RouterConfig {
|
|||||||
pub log_dir: Option<String>,
|
pub log_dir: Option<String>,
|
||||||
/// Log level (None = info)
|
/// Log level (None = info)
|
||||||
pub log_level: Option<String>,
|
pub log_level: Option<String>,
|
||||||
|
/// Custom request ID headers to check (defaults to common headers)
|
||||||
|
pub request_id_headers: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Routing mode configuration
|
/// Routing mode configuration
|
||||||
@@ -207,6 +209,7 @@ impl Default for RouterConfig {
|
|||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
|
request_id_headers: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -312,6 +315,7 @@ mod tests {
|
|||||||
metrics: Some(MetricsConfig::default()),
|
metrics: Some(MetricsConfig::default()),
|
||||||
log_dir: Some("/var/log".to_string()),
|
log_dir: Some("/var/log".to_string()),
|
||||||
log_level: Some("debug".to_string()),
|
log_level: Some("debug".to_string()),
|
||||||
|
request_id_headers: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let json = serde_json::to_string(&config).unwrap();
|
let json = serde_json::to_string(&config).unwrap();
|
||||||
@@ -734,6 +738,7 @@ mod tests {
|
|||||||
}),
|
}),
|
||||||
log_dir: Some("/var/log/sglang".to_string()),
|
log_dir: Some("/var/log/sglang".to_string()),
|
||||||
log_level: Some("info".to_string()),
|
log_level: Some("info".to_string()),
|
||||||
|
request_id_headers: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.mode.is_pd_mode());
|
assert!(config.mode.is_pd_mode());
|
||||||
@@ -780,6 +785,7 @@ mod tests {
|
|||||||
metrics: Some(MetricsConfig::default()),
|
metrics: Some(MetricsConfig::default()),
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: Some("debug".to_string()),
|
log_level: Some("debug".to_string()),
|
||||||
|
request_id_headers: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(!config.mode.is_pd_mode());
|
assert!(!config.mode.is_pd_mode());
|
||||||
@@ -822,6 +828,7 @@ mod tests {
|
|||||||
}),
|
}),
|
||||||
log_dir: Some("/opt/logs/sglang".to_string()),
|
log_dir: Some("/opt/logs/sglang".to_string()),
|
||||||
log_level: Some("trace".to_string()),
|
log_level: Some("trace".to_string()),
|
||||||
|
request_id_headers: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.has_service_discovery());
|
assert!(config.has_service_discovery());
|
||||||
|
|||||||
@@ -411,7 +411,7 @@ pub fn start_health_checker(
|
|||||||
|
|
||||||
// Check for shutdown signal
|
// Check for shutdown signal
|
||||||
if shutdown_clone.load(Ordering::Acquire) {
|
if shutdown_clone.load(Ordering::Acquire) {
|
||||||
tracing::info!("Health checker shutting down");
|
tracing::debug!("Health checker shutting down");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -439,6 +439,9 @@ pub fn start_health_checker(
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
if was_healthy {
|
if was_healthy {
|
||||||
tracing::warn!("Worker {} health check failed: {}", worker_url, e);
|
tracing::warn!("Worker {} health check failed: {}", worker_url, e);
|
||||||
|
} else {
|
||||||
|
// Worker was already unhealthy, log at debug level
|
||||||
|
tracing::debug!("Worker {} remains unhealthy: {}", worker_url, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ pub mod logging;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
pub mod core;
|
pub mod core;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
|
pub mod middleware;
|
||||||
pub mod openai_api_types;
|
pub mod openai_api_types;
|
||||||
pub mod policies;
|
pub mod policies;
|
||||||
pub mod routers;
|
pub mod routers;
|
||||||
@@ -49,6 +50,7 @@ struct Router {
|
|||||||
prometheus_port: Option<u16>,
|
prometheus_port: Option<u16>,
|
||||||
prometheus_host: Option<String>,
|
prometheus_host: Option<String>,
|
||||||
request_timeout_secs: u64,
|
request_timeout_secs: u64,
|
||||||
|
request_id_headers: Option<Vec<String>>,
|
||||||
// PD mode flag
|
// PD mode flag
|
||||||
pd_disaggregation: bool,
|
pd_disaggregation: bool,
|
||||||
// PD-specific fields (only used when pd_disaggregation is true)
|
// PD-specific fields (only used when pd_disaggregation is true)
|
||||||
@@ -138,6 +140,7 @@ impl Router {
|
|||||||
metrics,
|
metrics,
|
||||||
log_dir: self.log_dir.clone(),
|
log_dir: self.log_dir.clone(),
|
||||||
log_level: self.log_level.clone(),
|
log_level: self.log_level.clone(),
|
||||||
|
request_id_headers: self.request_id_headers.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -170,6 +173,7 @@ impl Router {
|
|||||||
prometheus_port = None,
|
prometheus_port = None,
|
||||||
prometheus_host = None,
|
prometheus_host = None,
|
||||||
request_timeout_secs = 600, // Add configurable request timeout
|
request_timeout_secs = 600, // Add configurable request timeout
|
||||||
|
request_id_headers = None, // Custom request ID headers
|
||||||
pd_disaggregation = false, // New flag for PD mode
|
pd_disaggregation = false, // New flag for PD mode
|
||||||
prefill_urls = None,
|
prefill_urls = None,
|
||||||
decode_urls = None,
|
decode_urls = None,
|
||||||
@@ -201,6 +205,7 @@ impl Router {
|
|||||||
prometheus_port: Option<u16>,
|
prometheus_port: Option<u16>,
|
||||||
prometheus_host: Option<String>,
|
prometheus_host: Option<String>,
|
||||||
request_timeout_secs: u64,
|
request_timeout_secs: u64,
|
||||||
|
request_id_headers: Option<Vec<String>>,
|
||||||
pd_disaggregation: bool,
|
pd_disaggregation: bool,
|
||||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||||
decode_urls: Option<Vec<String>>,
|
decode_urls: Option<Vec<String>>,
|
||||||
@@ -232,6 +237,7 @@ impl Router {
|
|||||||
prometheus_port,
|
prometheus_port,
|
||||||
prometheus_host,
|
prometheus_host,
|
||||||
request_timeout_secs,
|
request_timeout_secs,
|
||||||
|
request_id_headers,
|
||||||
pd_disaggregation,
|
pd_disaggregation,
|
||||||
prefill_urls,
|
prefill_urls,
|
||||||
decode_urls,
|
decode_urls,
|
||||||
@@ -297,6 +303,7 @@ impl Router {
|
|||||||
service_discovery_config,
|
service_discovery_config,
|
||||||
prometheus_config,
|
prometheus_config,
|
||||||
request_timeout_secs: self.request_timeout_secs,
|
request_timeout_secs: self.request_timeout_secs,
|
||||||
|
request_id_headers: self.request_id_headers.clone(),
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
|
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
|
||||||
|
|||||||
111
sgl-router/src/middleware.rs
Normal file
111
sgl-router/src/middleware.rs
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
use actix_web::{
|
||||||
|
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
|
||||||
|
Error, HttpMessage, HttpRequest,
|
||||||
|
};
|
||||||
|
use futures_util::future::LocalBoxFuture;
|
||||||
|
use std::future::{ready, Ready};
|
||||||
|
|
||||||
|
/// Generate OpenAI-compatible request ID based on endpoint
|
||||||
|
fn generate_request_id(path: &str) -> String {
|
||||||
|
let prefix = if path.contains("/chat/completions") {
|
||||||
|
"chatcmpl-"
|
||||||
|
} else if path.contains("/completions") {
|
||||||
|
"cmpl-"
|
||||||
|
} else if path.contains("/generate") {
|
||||||
|
"gnt-"
|
||||||
|
} else {
|
||||||
|
"req-"
|
||||||
|
};
|
||||||
|
|
||||||
|
// Generate a random string similar to OpenAI's format
|
||||||
|
let random_part: String = (0..24)
|
||||||
|
.map(|_| {
|
||||||
|
let chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
|
||||||
|
chars
|
||||||
|
.chars()
|
||||||
|
.nth(rand::random::<usize>() % chars.len())
|
||||||
|
.unwrap()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
format!("{}{}", prefix, random_part)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract request ID from request extensions or generate a new one
|
||||||
|
pub fn get_request_id(req: &HttpRequest) -> String {
|
||||||
|
req.extensions()
|
||||||
|
.get::<String>()
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| generate_request_id(req.path()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Middleware for injecting request ID into request extensions
///
/// Carries the list of header names that the wrapped service inspects when
/// looking for a caller-supplied request ID.
pub struct RequestIdMiddleware {
    /// Header names to look up on each request, in the order given.
    headers: Vec<String>,
}

impl RequestIdMiddleware {
    /// Build the middleware from the header names to inspect.
    pub fn new(headers: Vec<String>) -> Self {
        RequestIdMiddleware { headers }
    }
}
|
||||||
|
|
||||||
|
impl<S, B> Transform<S, ServiceRequest> for RequestIdMiddleware
|
||||||
|
where
|
||||||
|
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
||||||
|
S::Future: 'static,
|
||||||
|
B: 'static,
|
||||||
|
{
|
||||||
|
type Response = ServiceResponse<B>;
|
||||||
|
type Error = Error;
|
||||||
|
type InitError = ();
|
||||||
|
type Transform = RequestIdMiddlewareService<S>;
|
||||||
|
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
||||||
|
|
||||||
|
fn new_transform(&self, service: S) -> Self::Future {
|
||||||
|
ready(Ok(RequestIdMiddlewareService {
|
||||||
|
service,
|
||||||
|
headers: self.headers.clone(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RequestIdMiddlewareService<S> {
|
||||||
|
service: S,
|
||||||
|
headers: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, B> Service<ServiceRequest> for RequestIdMiddlewareService<S>
|
||||||
|
where
|
||||||
|
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
||||||
|
S::Future: 'static,
|
||||||
|
B: 'static,
|
||||||
|
{
|
||||||
|
type Response = ServiceResponse<B>;
|
||||||
|
type Error = Error;
|
||||||
|
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
|
||||||
|
|
||||||
|
forward_ready!(service);
|
||||||
|
|
||||||
|
fn call(&self, req: ServiceRequest) -> Self::Future {
|
||||||
|
// Extract request ID from headers or generate new one
|
||||||
|
let mut request_id = None;
|
||||||
|
|
||||||
|
for header_name in &self.headers {
|
||||||
|
if let Some(header_value) = req.headers().get(header_name) {
|
||||||
|
if let Ok(value) = header_value.to_str() {
|
||||||
|
request_id = Some(value.to_string());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let request_id = request_id.unwrap_or_else(|| generate_request_id(req.path()));
|
||||||
|
|
||||||
|
// Insert request ID into request extensions
|
||||||
|
req.extensions_mut().insert(request_id);
|
||||||
|
|
||||||
|
let fut = self.service.call(req);
|
||||||
|
Box::pin(async move { fut.await })
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -66,7 +66,7 @@ use crate::tree::Tree;
|
|||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tracing::{debug, info};
|
use tracing::debug;
|
||||||
|
|
||||||
/// Cache-aware routing policy
|
/// Cache-aware routing policy
|
||||||
///
|
///
|
||||||
@@ -164,10 +164,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
|||||||
.map(|w| (w.url().to_string(), w.load()))
|
.map(|w| (w.url().to_string(), w.load()))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
info!(
|
debug!(
|
||||||
"Load balancing triggered due to workload imbalance:\n\
|
"Load balancing triggered | max: {} | min: {} | workers: {:?}",
|
||||||
Max load: {}, Min load: {}\n\
|
|
||||||
Current worker loads: {:?}",
|
|
||||||
max_load, min_load, worker_loads
|
max_load, min_load, worker_loads
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRou
|
|||||||
use super::request_adapter::ToPdRequest;
|
use super::request_adapter::ToPdRequest;
|
||||||
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
|
use crate::middleware::get_request_id;
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
use crate::tree::Tree;
|
use crate::tree::Tree;
|
||||||
@@ -16,7 +17,6 @@ use std::collections::HashMap;
|
|||||||
use std::sync::{Arc, Mutex, RwLock};
|
use std::sync::{Arc, Mutex, RwLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct PDRouter {
|
pub struct PDRouter {
|
||||||
@@ -307,8 +307,8 @@ impl PDRouter {
|
|||||||
mut typed_req: GenerateReqInput,
|
mut typed_req: GenerateReqInput,
|
||||||
route: &str,
|
route: &str,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
|
let request_id = get_request_id(req);
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let _request_id = Uuid::new_v4();
|
|
||||||
|
|
||||||
// Get stream flag and return_logprob flag before moving the request
|
// Get stream flag and return_logprob flag before moving the request
|
||||||
let is_stream = typed_req.stream;
|
let is_stream = typed_req.stream;
|
||||||
@@ -328,7 +328,10 @@ impl PDRouter {
|
|||||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||||
Ok(pair) => pair,
|
Ok(pair) => pair,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to select PD pair: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to select PD pair error={}", e
|
||||||
|
);
|
||||||
RouterMetrics::record_pd_error("server_selection");
|
RouterMetrics::record_pd_error("server_selection");
|
||||||
return HttpResponse::ServiceUnavailable()
|
return HttpResponse::ServiceUnavailable()
|
||||||
.body(format!("No available servers: {}", e));
|
.body(format!("No available servers: {}", e));
|
||||||
@@ -337,15 +340,17 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Log routing decision
|
// Log routing decision
|
||||||
info!(
|
info!(
|
||||||
"PD routing: {} -> prefill={}, decode={}",
|
request_id = %request_id,
|
||||||
route,
|
"PD routing decision route={} prefill_url={} decode_url={}",
|
||||||
prefill.url(),
|
route, prefill.url(), decode.url()
|
||||||
decode.url()
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add bootstrap info using the trait method
|
// Add bootstrap info using the trait method
|
||||||
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||||
error!("Failed to add bootstrap info: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to add bootstrap info error={}", e
|
||||||
|
);
|
||||||
RouterMetrics::record_pd_error("bootstrap_injection");
|
RouterMetrics::record_pd_error("bootstrap_injection");
|
||||||
return HttpResponse::InternalServerError()
|
return HttpResponse::InternalServerError()
|
||||||
.body(format!("Bootstrap injection failed: {}", e));
|
.body(format!("Bootstrap injection failed: {}", e));
|
||||||
@@ -355,7 +360,10 @@ impl PDRouter {
|
|||||||
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
|
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
|
||||||
Ok(json) => json,
|
Ok(json) => json,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to serialize request: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to serialize request error={}", e
|
||||||
|
);
|
||||||
return HttpResponse::InternalServerError().body("Failed to serialize request");
|
return HttpResponse::InternalServerError().body("Failed to serialize request");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -383,6 +391,7 @@ impl PDRouter {
|
|||||||
mut typed_req: ChatReqInput,
|
mut typed_req: ChatReqInput,
|
||||||
route: &str,
|
route: &str,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
|
let request_id = get_request_id(req);
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
// Get stream flag and return_logprob flag before moving the request
|
// Get stream flag and return_logprob flag before moving the request
|
||||||
@@ -406,7 +415,10 @@ impl PDRouter {
|
|||||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||||
Ok(pair) => pair,
|
Ok(pair) => pair,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to select PD pair: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to select PD pair error={}", e
|
||||||
|
);
|
||||||
RouterMetrics::record_pd_error("server_selection");
|
RouterMetrics::record_pd_error("server_selection");
|
||||||
return HttpResponse::ServiceUnavailable()
|
return HttpResponse::ServiceUnavailable()
|
||||||
.body(format!("No available servers: {}", e));
|
.body(format!("No available servers: {}", e));
|
||||||
@@ -415,15 +427,17 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Log routing decision
|
// Log routing decision
|
||||||
info!(
|
info!(
|
||||||
"PD routing: {} -> prefill={}, decode={}",
|
request_id = %request_id,
|
||||||
route,
|
"PD routing decision route={} prefill_url={} decode_url={}",
|
||||||
prefill.url(),
|
route, prefill.url(), decode.url()
|
||||||
decode.url()
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add bootstrap info using the trait method
|
// Add bootstrap info using the trait method
|
||||||
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||||
error!("Failed to add bootstrap info: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to add bootstrap info error={}", e
|
||||||
|
);
|
||||||
RouterMetrics::record_pd_error("bootstrap_injection");
|
RouterMetrics::record_pd_error("bootstrap_injection");
|
||||||
return HttpResponse::InternalServerError()
|
return HttpResponse::InternalServerError()
|
||||||
.body(format!("Bootstrap injection failed: {}", e));
|
.body(format!("Bootstrap injection failed: {}", e));
|
||||||
@@ -433,7 +447,10 @@ impl PDRouter {
|
|||||||
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
|
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
|
||||||
Ok(json) => json,
|
Ok(json) => json,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to serialize request: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to serialize request error={}", e
|
||||||
|
);
|
||||||
return HttpResponse::InternalServerError().body("Failed to serialize request");
|
return HttpResponse::InternalServerError().body("Failed to serialize request");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -461,6 +478,7 @@ impl PDRouter {
|
|||||||
mut typed_req: CompletionRequest,
|
mut typed_req: CompletionRequest,
|
||||||
route: &str,
|
route: &str,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
|
let request_id = get_request_id(req);
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
// Get stream flag and return_logprob flag before moving the request
|
// Get stream flag and return_logprob flag before moving the request
|
||||||
@@ -477,7 +495,10 @@ impl PDRouter {
|
|||||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||||
Ok(pair) => pair,
|
Ok(pair) => pair,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to select PD pair: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to select PD pair error={}", e
|
||||||
|
);
|
||||||
RouterMetrics::record_pd_error("server_selection");
|
RouterMetrics::record_pd_error("server_selection");
|
||||||
return HttpResponse::ServiceUnavailable()
|
return HttpResponse::ServiceUnavailable()
|
||||||
.body(format!("No available servers: {}", e));
|
.body(format!("No available servers: {}", e));
|
||||||
@@ -486,15 +507,17 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Log routing decision
|
// Log routing decision
|
||||||
info!(
|
info!(
|
||||||
"PD routing: {} -> prefill={}, decode={}",
|
request_id = %request_id,
|
||||||
route,
|
"PD routing decision route={} prefill_url={} decode_url={}",
|
||||||
prefill.url(),
|
route, prefill.url(), decode.url()
|
||||||
decode.url()
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add bootstrap info using the trait method
|
// Add bootstrap info using the trait method
|
||||||
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||||
error!("Failed to add bootstrap info: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to add bootstrap info error={}", e
|
||||||
|
);
|
||||||
RouterMetrics::record_pd_error("bootstrap_injection");
|
RouterMetrics::record_pd_error("bootstrap_injection");
|
||||||
return HttpResponse::InternalServerError()
|
return HttpResponse::InternalServerError()
|
||||||
.body(format!("Bootstrap injection failed: {}", e));
|
.body(format!("Bootstrap injection failed: {}", e));
|
||||||
@@ -504,7 +527,10 @@ impl PDRouter {
|
|||||||
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
|
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
|
||||||
Ok(json) => json,
|
Ok(json) => json,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to serialize request: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to serialize request error={}", e
|
||||||
|
);
|
||||||
return HttpResponse::InternalServerError().body("Failed to serialize request");
|
return HttpResponse::InternalServerError().body("Failed to serialize request");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -538,6 +564,7 @@ impl PDRouter {
|
|||||||
return_logprob: bool,
|
return_logprob: bool,
|
||||||
start_time: Instant,
|
start_time: Instant,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
|
let request_id = get_request_id(req);
|
||||||
// Update load tracking for both workers
|
// Update load tracking for both workers
|
||||||
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
||||||
|
|
||||||
@@ -578,9 +605,9 @@ impl PDRouter {
|
|||||||
if !status.is_success() {
|
if !status.is_success() {
|
||||||
RouterMetrics::record_pd_decode_error(decode.url());
|
RouterMetrics::record_pd_decode_error(decode.url());
|
||||||
error!(
|
error!(
|
||||||
"Decode server {} returned error status: {}",
|
request_id = %request_id,
|
||||||
decode.url(),
|
"Decode server returned error status decode_url={} status={}",
|
||||||
status
|
decode.url(), status
|
||||||
);
|
);
|
||||||
|
|
||||||
// Return the error response from decode server
|
// Return the error response from decode server
|
||||||
@@ -598,9 +625,9 @@ impl PDRouter {
|
|||||||
// Log prefill errors for debugging
|
// Log prefill errors for debugging
|
||||||
if let Err(e) = &prefill_result {
|
if let Err(e) = &prefill_result {
|
||||||
error!(
|
error!(
|
||||||
"Prefill server {} failed (non-critical): {}",
|
request_id = %request_id,
|
||||||
prefill.url(),
|
"Prefill server failed (non-critical) prefill_url={} error={}",
|
||||||
e
|
prefill.url(), e
|
||||||
);
|
);
|
||||||
RouterMetrics::record_pd_prefill_error(prefill.url());
|
RouterMetrics::record_pd_prefill_error(prefill.url());
|
||||||
}
|
}
|
||||||
@@ -684,7 +711,12 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Decode request failed: {}", e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
decode_url = %decode.url(),
|
||||||
|
error = %e,
|
||||||
|
"Decode request failed"
|
||||||
|
);
|
||||||
RouterMetrics::record_pd_decode_error(decode.url());
|
RouterMetrics::record_pd_decode_error(decode.url());
|
||||||
HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
|
HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
use crate::core::{HealthChecker, Worker, WorkerFactory};
|
use crate::core::{HealthChecker, Worker, WorkerFactory};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
|
use crate::middleware::get_request_id;
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||||
use actix_web::{HttpRequest, HttpResponse};
|
use actix_web::{HttpRequest, HttpResponse};
|
||||||
@@ -134,32 +135,26 @@ impl Router {
|
|||||||
match sync_client.get(&format!("{}/health", url)).send() {
|
match sync_client.get(&format!("{}/health", url)).send() {
|
||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
if !res.status().is_success() {
|
if !res.status().is_success() {
|
||||||
let msg = format!(
|
|
||||||
"Worker heatlh check is pending with status {}",
|
|
||||||
res.status()
|
|
||||||
);
|
|
||||||
info!("{}", msg);
|
|
||||||
all_healthy = false;
|
all_healthy = false;
|
||||||
unhealthy_workers.push((url, msg));
|
unhealthy_workers.push((url, format!("status: {}", res.status())));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
let msg = format!("Worker is not ready yet");
|
|
||||||
info!("{}", msg);
|
|
||||||
all_healthy = false;
|
all_healthy = false;
|
||||||
unhealthy_workers.push((url, msg));
|
unhealthy_workers.push((url, "not ready".to_string()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if all_healthy {
|
if all_healthy {
|
||||||
info!("All workers are healthy");
|
info!("All {} workers are healthy", worker_urls.len());
|
||||||
return Ok(());
|
return Ok(());
|
||||||
} else {
|
} else {
|
||||||
info!("Initializing workers:");
|
debug!(
|
||||||
for (url, reason) in &unhealthy_workers {
|
"Waiting for {} workers to become healthy ({} unhealthy)",
|
||||||
info!(" {} - {}", url, reason);
|
worker_urls.len(),
|
||||||
}
|
unhealthy_workers.len()
|
||||||
|
);
|
||||||
thread::sleep(Duration::from_secs(interval_secs));
|
thread::sleep(Duration::from_secs(interval_secs));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -181,6 +176,7 @@ impl Router {
|
|||||||
route: &str,
|
route: &str,
|
||||||
req: &HttpRequest,
|
req: &HttpRequest,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
|
let request_id = get_request_id(req);
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let mut request_builder = client.get(format!("{}{}", worker_url, route));
|
let mut request_builder = client.get(format!("{}{}", worker_url, route));
|
||||||
|
|
||||||
@@ -202,14 +198,32 @@ impl Router {
|
|||||||
|
|
||||||
match res.bytes().await {
|
match res.bytes().await {
|
||||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||||
Err(e) => HttpResponse::InternalServerError()
|
Err(e) => {
|
||||||
.body(format!("Failed to read response body: {}", e)),
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
worker_url = %worker_url,
|
||||||
|
route = %route,
|
||||||
|
error = %e,
|
||||||
|
"Failed to read response body"
|
||||||
|
);
|
||||||
|
HttpResponse::InternalServerError()
|
||||||
|
.body(format!("Failed to read response body: {}", e))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => HttpResponse::InternalServerError().body(format!(
|
Err(e) => {
|
||||||
"Failed to send request to worker {}: {}",
|
error!(
|
||||||
worker_url, e
|
request_id = %request_id,
|
||||||
)),
|
worker_url = %worker_url,
|
||||||
|
route = %route,
|
||||||
|
error = %e,
|
||||||
|
"Failed to send request to worker"
|
||||||
|
);
|
||||||
|
HttpResponse::InternalServerError().body(format!(
|
||||||
|
"Failed to send request to worker {}: {}",
|
||||||
|
worker_url, e
|
||||||
|
))
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Record request metrics
|
// Record request metrics
|
||||||
@@ -231,6 +245,7 @@ impl Router {
|
|||||||
route: &str,
|
route: &str,
|
||||||
req: &HttpRequest,
|
req: &HttpRequest,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
|
let request_id = get_request_id(req);
|
||||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||||
let mut total_retries = 0;
|
let mut total_retries = 0;
|
||||||
@@ -260,17 +275,23 @@ impl Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
warn!(
|
warn!(
|
||||||
"Request to {} failed (attempt {}/{})",
|
request_id = %request_id,
|
||||||
worker_url,
|
route = %route,
|
||||||
request_retries + 1,
|
worker_url = %worker_url,
|
||||||
MAX_REQUEST_RETRIES
|
attempt = request_retries + 1,
|
||||||
|
max_attempts = MAX_REQUEST_RETRIES,
|
||||||
|
"Request failed"
|
||||||
);
|
);
|
||||||
|
|
||||||
request_retries += 1;
|
request_retries += 1;
|
||||||
total_retries += 1;
|
total_retries += 1;
|
||||||
|
|
||||||
if request_retries == MAX_REQUEST_RETRIES {
|
if request_retries == MAX_REQUEST_RETRIES {
|
||||||
warn!("Removing failed worker: {}", worker_url);
|
warn!(
|
||||||
|
request_id = %request_id,
|
||||||
|
worker_url = %worker_url,
|
||||||
|
"Removing failed worker"
|
||||||
|
);
|
||||||
self.remove_worker(&worker_url);
|
self.remove_worker(&worker_url);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -293,6 +314,7 @@ impl Router {
|
|||||||
typed_req: &T,
|
typed_req: &T,
|
||||||
route: &str,
|
route: &str,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
|
let request_id = get_request_id(req);
|
||||||
// Handle retries like the original implementation
|
// Handle retries like the original implementation
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||||
@@ -357,17 +379,19 @@ impl Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
warn!(
|
warn!(
|
||||||
"Generate request to {} failed (attempt {}/{})",
|
request_id = %request_id,
|
||||||
worker_url,
|
"Generate request failed route={} worker_url={} attempt={} max_attempts={}",
|
||||||
request_retries + 1,
|
route, worker_url, request_retries + 1, MAX_REQUEST_RETRIES
|
||||||
MAX_REQUEST_RETRIES
|
|
||||||
);
|
);
|
||||||
|
|
||||||
request_retries += 1;
|
request_retries += 1;
|
||||||
total_retries += 1;
|
total_retries += 1;
|
||||||
|
|
||||||
if request_retries == MAX_REQUEST_RETRIES {
|
if request_retries == MAX_REQUEST_RETRIES {
|
||||||
warn!("Removing failed worker: {}", worker_url);
|
warn!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Removing failed worker after typed request failures worker_url={}", worker_url
|
||||||
|
);
|
||||||
self.remove_worker(&worker_url);
|
self.remove_worker(&worker_url);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -402,13 +426,9 @@ impl Router {
|
|||||||
is_stream: bool,
|
is_stream: bool,
|
||||||
load_incremented: bool, // Whether load was incremented for this request
|
load_incremented: bool, // Whether load was incremented for this request
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
|
let request_id = get_request_id(req);
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
// Debug: Log what we're sending
|
|
||||||
if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
|
|
||||||
debug!("Sending request to {}: {}", route, json_str);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut request_builder = client
|
let mut request_builder = client
|
||||||
.post(format!("{}{}", worker_url, route))
|
.post(format!("{}{}", worker_url, route))
|
||||||
.json(typed_req); // Use json() directly with typed request
|
.json(typed_req); // Use json() directly with typed request
|
||||||
@@ -424,7 +444,11 @@ impl Router {
|
|||||||
let res = match request_builder.send().await {
|
let res = match request_builder.send().await {
|
||||||
Ok(res) => res,
|
Ok(res) => res,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to send request to {}: {}", worker_url, e);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to send typed request worker_url={} route={} error={}",
|
||||||
|
worker_url, route, e
|
||||||
|
);
|
||||||
|
|
||||||
// Decrement load on error if it was incremented
|
// Decrement load on error if it was incremented
|
||||||
if load_incremented {
|
if load_incremented {
|
||||||
@@ -497,7 +521,6 @@ impl Router {
|
|||||||
&worker_url,
|
&worker_url,
|
||||||
worker.load(),
|
worker.load(),
|
||||||
);
|
);
|
||||||
debug!("Streaming is done!!")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -536,7 +559,6 @@ impl Router {
|
|||||||
match client.get(&format!("{}/health", worker_url)).send().await {
|
match client.get(&format!("{}/health", worker_url)).send().await {
|
||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
if res.status().is_success() {
|
if res.status().is_success() {
|
||||||
info!("Worker {} health check passed", worker_url);
|
|
||||||
let mut workers_guard = self.workers.write().unwrap();
|
let mut workers_guard = self.workers.write().unwrap();
|
||||||
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||||
return Err(format!("Worker {} already exists", worker_url));
|
return Err(format!("Worker {} already exists", worker_url));
|
||||||
@@ -560,8 +582,8 @@ impl Router {
|
|||||||
|
|
||||||
return Ok(format!("Successfully added worker: {}", worker_url));
|
return Ok(format!("Successfully added worker: {}", worker_url));
|
||||||
} else {
|
} else {
|
||||||
info!(
|
debug!(
|
||||||
"Worker {} health check is pending with status: {}.",
|
"Worker {} health check pending - status: {}",
|
||||||
worker_url,
|
worker_url,
|
||||||
res.status()
|
res.status()
|
||||||
);
|
);
|
||||||
@@ -576,10 +598,7 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
info!(
|
debug!("Worker {} health check pending - error: {}", worker_url, e);
|
||||||
"Worker {} health check is pending with error: {}",
|
|
||||||
worker_url, e
|
|
||||||
);
|
|
||||||
|
|
||||||
// if the url does not have http or https prefix, warn users
|
// if the url does not have http or https prefix, warn users
|
||||||
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
|
if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") {
|
||||||
@@ -611,7 +630,6 @@ impl Router {
|
|||||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
{
|
{
|
||||||
cache_aware.remove_worker(worker_url);
|
cache_aware.remove_worker(worker_url);
|
||||||
info!("Removed worker from tree: {}", worker_url);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -675,7 +693,6 @@ impl Router {
|
|||||||
for url in &worker_urls {
|
for url in &worker_urls {
|
||||||
if let Some(load) = Self::get_worker_load_static(&client, url).await {
|
if let Some(load) = Self::get_worker_load_static(&client, url).await {
|
||||||
loads.insert(url.clone(), load);
|
loads.insert(url.clone(), load);
|
||||||
debug!("Worker {} load: {}", url, load);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use crate::config::RouterConfig;
|
use crate::config::RouterConfig;
|
||||||
use crate::logging::{self, LoggingConfig};
|
use crate::logging::{self, LoggingConfig};
|
||||||
use crate::metrics::{self, PrometheusConfig};
|
use crate::metrics::{self, PrometheusConfig};
|
||||||
|
use crate::middleware::{get_request_id, RequestIdMiddleware};
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::routers::{RouterFactory, RouterTrait};
|
use crate::routers::{RouterFactory, RouterTrait};
|
||||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||||
@@ -46,13 +47,13 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Custom error handler for JSON payload errors.
|
// Custom error handler for JSON payload errors.
|
||||||
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
|
fn json_error_handler(err: error::JsonPayloadError, req: &HttpRequest) -> Error {
|
||||||
error!("JSON payload error: {:?}", err);
|
let request_id = get_request_id(req);
|
||||||
match &err {
|
match &err {
|
||||||
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
|
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
|
||||||
error!(
|
error!(
|
||||||
"Payload too large: {} bytes exceeds limit of {} bytes",
|
request_id = %request_id,
|
||||||
length, limit
|
"Payload too large length={} limit={}", length, limit
|
||||||
);
|
);
|
||||||
error::ErrorPayloadTooLarge(format!(
|
error::ErrorPayloadTooLarge(format!(
|
||||||
"Payload too large: {} bytes exceeds limit of {} bytes",
|
"Payload too large: {} bytes exceeds limit of {} bytes",
|
||||||
@@ -60,10 +61,19 @@ fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
error::JsonPayloadError::Overflow { limit } => {
|
error::JsonPayloadError::Overflow { limit } => {
|
||||||
error!("Payload overflow: exceeds limit of {} bytes", limit);
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Payload overflow limit={}", limit
|
||||||
|
);
|
||||||
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
|
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
|
||||||
}
|
}
|
||||||
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
|
_ => {
|
||||||
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Invalid JSON payload error={}", err
|
||||||
|
);
|
||||||
|
error::ErrorBadRequest(format!("Invalid JSON payload: {}", err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,8 +118,20 @@ async fn generate(
|
|||||||
body: web::Json<GenerateRequest>,
|
body: web::Json<GenerateRequest>,
|
||||||
state: web::Data<AppState>,
|
state: web::Data<AppState>,
|
||||||
) -> Result<HttpResponse, Error> {
|
) -> Result<HttpResponse, Error> {
|
||||||
let json_body = serde_json::to_value(body.into_inner())
|
let request_id = get_request_id(&req);
|
||||||
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
|
info!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Received generate request method=\"POST\" path=\"/generate\""
|
||||||
|
);
|
||||||
|
|
||||||
|
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
|
||||||
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to parse generate request body error={}", e
|
||||||
|
);
|
||||||
|
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(state
|
Ok(state
|
||||||
.router
|
.router
|
||||||
.route_generate(&state.client, &req, json_body)
|
.route_generate(&state.client, &req, json_body)
|
||||||
@@ -122,8 +144,20 @@ async fn v1_chat_completions(
|
|||||||
body: web::Json<ChatCompletionRequest>,
|
body: web::Json<ChatCompletionRequest>,
|
||||||
state: web::Data<AppState>,
|
state: web::Data<AppState>,
|
||||||
) -> Result<HttpResponse, Error> {
|
) -> Result<HttpResponse, Error> {
|
||||||
let json_body = serde_json::to_value(body.into_inner())
|
let request_id = get_request_id(&req);
|
||||||
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
|
info!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Received chat completion request method=\"POST\" path=\"/v1/chat/completions\""
|
||||||
|
);
|
||||||
|
|
||||||
|
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
|
||||||
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to parse chat completion request body error={}", e
|
||||||
|
);
|
||||||
|
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(state
|
Ok(state
|
||||||
.router
|
.router
|
||||||
.route_chat(&state.client, &req, json_body)
|
.route_chat(&state.client, &req, json_body)
|
||||||
@@ -136,8 +170,20 @@ async fn v1_completions(
|
|||||||
body: web::Json<CompletionRequest>,
|
body: web::Json<CompletionRequest>,
|
||||||
state: web::Data<AppState>,
|
state: web::Data<AppState>,
|
||||||
) -> Result<HttpResponse, Error> {
|
) -> Result<HttpResponse, Error> {
|
||||||
let json_body = serde_json::to_value(body.into_inner())
|
let request_id = get_request_id(&req);
|
||||||
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
|
info!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Received completion request method=\"POST\" path=\"/v1/completions\""
|
||||||
|
);
|
||||||
|
|
||||||
|
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
|
||||||
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Failed to parse completion request body error={}", e
|
||||||
|
);
|
||||||
|
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(state
|
Ok(state
|
||||||
.router
|
.router
|
||||||
.route_completion(&state.client, &req, json_body)
|
.route_completion(&state.client, &req, json_body)
|
||||||
@@ -146,20 +192,48 @@ async fn v1_completions(
|
|||||||
|
|
||||||
#[post("/add_worker")]
|
#[post("/add_worker")]
|
||||||
async fn add_worker(
|
async fn add_worker(
|
||||||
|
req: HttpRequest,
|
||||||
query: web::Query<HashMap<String, String>>,
|
query: web::Query<HashMap<String, String>>,
|
||||||
data: web::Data<AppState>,
|
data: web::Data<AppState>,
|
||||||
) -> impl Responder {
|
) -> impl Responder {
|
||||||
|
let request_id = get_request_id(&req);
|
||||||
|
|
||||||
let worker_url = match query.get("url") {
|
let worker_url = match query.get("url") {
|
||||||
Some(url) => url.to_string(),
|
Some(url) => url.to_string(),
|
||||||
None => {
|
None => {
|
||||||
|
warn!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Add worker request missing URL parameter"
|
||||||
|
);
|
||||||
return HttpResponse::BadRequest()
|
return HttpResponse::BadRequest()
|
||||||
.body("Worker URL required. Provide 'url' query parameter")
|
.body("Worker URL required. Provide 'url' query parameter");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
info!(
|
||||||
|
request_id = %request_id,
|
||||||
|
worker_url = %worker_url,
|
||||||
|
"Adding worker"
|
||||||
|
);
|
||||||
|
|
||||||
match data.router.add_worker(&worker_url).await {
|
match data.router.add_worker(&worker_url).await {
|
||||||
Ok(message) => HttpResponse::Ok().body(message),
|
Ok(message) => {
|
||||||
Err(error) => HttpResponse::BadRequest().body(error),
|
info!(
|
||||||
|
request_id = %request_id,
|
||||||
|
worker_url = %worker_url,
|
||||||
|
"Successfully added worker"
|
||||||
|
);
|
||||||
|
HttpResponse::Ok().body(message)
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
error!(
|
||||||
|
request_id = %request_id,
|
||||||
|
worker_url = %worker_url,
|
||||||
|
error = %error,
|
||||||
|
"Failed to add worker"
|
||||||
|
);
|
||||||
|
HttpResponse::BadRequest().body(error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,13 +245,29 @@ async fn list_workers(data: web::Data<AppState>) -> impl Responder {
|
|||||||
|
|
||||||
#[post("/remove_worker")]
|
#[post("/remove_worker")]
|
||||||
async fn remove_worker(
|
async fn remove_worker(
|
||||||
|
req: HttpRequest,
|
||||||
query: web::Query<HashMap<String, String>>,
|
query: web::Query<HashMap<String, String>>,
|
||||||
data: web::Data<AppState>,
|
data: web::Data<AppState>,
|
||||||
) -> impl Responder {
|
) -> impl Responder {
|
||||||
|
let request_id = get_request_id(&req);
|
||||||
|
|
||||||
let worker_url = match query.get("url") {
|
let worker_url = match query.get("url") {
|
||||||
Some(url) => url.to_string(),
|
Some(url) => url.to_string(),
|
||||||
None => return HttpResponse::BadRequest().finish(),
|
None => {
|
||||||
|
warn!(
|
||||||
|
request_id = %request_id,
|
||||||
|
"Remove worker request missing URL parameter"
|
||||||
|
);
|
||||||
|
return HttpResponse::BadRequest().finish();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
info!(
|
||||||
|
request_id = %request_id,
|
||||||
|
worker_url = %worker_url,
|
||||||
|
"Removing worker"
|
||||||
|
);
|
||||||
|
|
||||||
data.router.remove_worker(&worker_url);
|
data.router.remove_worker(&worker_url);
|
||||||
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
|
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
|
||||||
}
|
}
|
||||||
@@ -202,6 +292,7 @@ pub struct ServerConfig {
|
|||||||
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
|
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
|
||||||
pub prometheus_config: Option<PrometheusConfig>,
|
pub prometheus_config: Option<PrometheusConfig>,
|
||||||
pub request_timeout_secs: u64,
|
pub request_timeout_secs: u64,
|
||||||
|
pub request_id_headers: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||||
@@ -233,31 +324,18 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
|||||||
|
|
||||||
// Initialize prometheus metrics exporter
|
// Initialize prometheus metrics exporter
|
||||||
if let Some(prometheus_config) = config.prometheus_config {
|
if let Some(prometheus_config) = config.prometheus_config {
|
||||||
info!(
|
|
||||||
"🚧 Initializing Prometheus metrics on {}:{}",
|
|
||||||
prometheus_config.host, prometheus_config.port
|
|
||||||
);
|
|
||||||
metrics::start_prometheus(prometheus_config);
|
metrics::start_prometheus(prometheus_config);
|
||||||
} else {
|
|
||||||
info!("🚧 Prometheus metrics disabled");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("🚧 Initializing router on {}:{}", config.host, config.port);
|
|
||||||
info!("🚧 Router mode: {:?}", config.router_config.mode);
|
|
||||||
info!("🚧 Policy: {:?}", config.router_config.policy);
|
|
||||||
info!(
|
info!(
|
||||||
"🚧 Max payload size: {} MB",
|
"Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
|
||||||
|
config.host,
|
||||||
|
config.port,
|
||||||
|
config.router_config.mode,
|
||||||
|
config.router_config.policy,
|
||||||
config.max_payload_size / (1024 * 1024)
|
config.max_payload_size / (1024 * 1024)
|
||||||
);
|
);
|
||||||
|
|
||||||
// Log service discovery status
|
|
||||||
if let Some(service_discovery_config) = &config.service_discovery_config {
|
|
||||||
info!("🚧 Service discovery enabled");
|
|
||||||
info!("🚧 Selector: {:?}", service_discovery_config.selector);
|
|
||||||
} else {
|
|
||||||
info!("🚧 Service discovery disabled");
|
|
||||||
}
|
|
||||||
|
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
||||||
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
|
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
|
||||||
@@ -272,11 +350,9 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
|||||||
// Start the service discovery if enabled
|
// Start the service discovery if enabled
|
||||||
if let Some(service_discovery_config) = config.service_discovery_config {
|
if let Some(service_discovery_config) = config.service_discovery_config {
|
||||||
if service_discovery_config.enabled {
|
if service_discovery_config.enabled {
|
||||||
info!("🚧 Initializing Kubernetes service discovery");
|
|
||||||
// Pass the Arc<Router> directly
|
|
||||||
match start_service_discovery(service_discovery_config, router_arc).await {
|
match start_service_discovery(service_discovery_config, router_arc).await {
|
||||||
Ok(handle) => {
|
Ok(handle) => {
|
||||||
info!("✅ Service discovery started successfully");
|
info!("Service discovery started");
|
||||||
// Spawn a task to handle the service discovery thread
|
// Spawn a task to handle the service discovery thread
|
||||||
spawn(async move {
|
spawn(async move {
|
||||||
if let Err(e) = handle.await {
|
if let Err(e) = handle.await {
|
||||||
@@ -292,14 +368,26 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("✅ Serving router on {}:{}", config.host, config.port);
|
|
||||||
info!(
|
info!(
|
||||||
"✅ Serving workers on {:?}",
|
"Router ready | workers: {:?}",
|
||||||
app_state.router.get_worker_urls()
|
app_state.router.get_worker_urls()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Configure request ID headers
|
||||||
|
let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
|
||||||
|
vec![
|
||||||
|
"x-request-id".to_string(),
|
||||||
|
"x-correlation-id".to_string(),
|
||||||
|
"x-trace-id".to_string(),
|
||||||
|
"request-id".to_string(),
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
HttpServer::new(move || {
|
HttpServer::new(move || {
|
||||||
|
let request_id_middleware = RequestIdMiddleware::new(request_id_headers.clone());
|
||||||
|
|
||||||
App::new()
|
App::new()
|
||||||
|
.wrap(request_id_middleware)
|
||||||
.app_data(app_state.clone())
|
.app_data(app_state.clone())
|
||||||
.app_data(
|
.app_data(
|
||||||
web::JsonConfig::default()
|
web::JsonConfig::default()
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ pub async fn start_service_discovery(
|
|||||||
.join(",");
|
.join(",");
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Starting Kubernetes service discovery in PD mode with prefill_selector: '{}', decode_selector: '{}'",
|
"Starting K8s service discovery | PD mode | prefill: '{}' | decode: '{}'",
|
||||||
prefill_selector, decode_selector
|
prefill_selector, decode_selector
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
@@ -221,7 +221,7 @@ pub async fn start_service_discovery(
|
|||||||
.join(",");
|
.join(",");
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Starting Kubernetes service discovery with selector: '{}'",
|
"Starting K8s service discovery | selector: '{}'",
|
||||||
label_selector
|
label_selector
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -238,7 +238,7 @@ pub async fn start_service_discovery(
|
|||||||
Api::all(client)
|
Api::all(client)
|
||||||
};
|
};
|
||||||
|
|
||||||
info!("Kubernetes service discovery initialized successfully");
|
debug!("K8s service discovery initialized");
|
||||||
|
|
||||||
// Create Arcs for configuration data
|
// Create Arcs for configuration data
|
||||||
let config_arc = Arc::new(config.clone());
|
let config_arc = Arc::new(config.clone());
|
||||||
@@ -375,7 +375,7 @@ async fn handle_pod_event(
|
|||||||
|
|
||||||
if should_add {
|
if should_add {
|
||||||
info!(
|
info!(
|
||||||
"Healthy pod found: {} (type: {:?}). Adding worker: {}",
|
"Adding pod: {} | type: {:?} | url: {}",
|
||||||
pod_info.name, pod_info.pod_type, worker_url
|
pod_info.name, pod_info.pod_type, worker_url
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -409,8 +409,8 @@ async fn handle_pod_event(
|
|||||||
};
|
};
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(msg) => {
|
Ok(_) => {
|
||||||
info!("Successfully added worker: {}", msg);
|
debug!("Worker added: {}", worker_url);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to add worker {} to router: {}", worker_url, e);
|
error!("Failed to add worker {} to router: {}", worker_url, e);
|
||||||
@@ -446,7 +446,7 @@ async fn handle_pod_deletion(
|
|||||||
|
|
||||||
if was_tracked {
|
if was_tracked {
|
||||||
info!(
|
info!(
|
||||||
"Pod deleted: {} (type: {:?}). Removing worker: {}",
|
"Removing pod: {} | type: {:?} | url: {}",
|
||||||
pod_info.name, pod_info.pod_type, worker_url
|
pod_info.name, pod_info.pod_type, worker_url
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ impl TestContext {
|
|||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
|
request_id_headers: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
Self::new_with_config(config, worker_configs).await
|
Self::new_with_config(config, worker_configs).await
|
||||||
@@ -953,6 +954,7 @@ mod error_tests {
|
|||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
|
request_id_headers: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = TestContext::new_with_config(
|
let ctx = TestContext::new_with_config(
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
|
|||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
|
request_id_headers: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,6 +41,7 @@ pub fn create_test_config_no_workers() -> RouterConfig {
|
|||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
|
request_id_headers: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ impl RequestTestContext {
|
|||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
|
request_id_headers: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ impl StreamingTestContext {
|
|||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
|
request_id_headers: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ mod test_pd_routing {
|
|||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
|
request_id_headers: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Router creation will fail due to health checks, but config should be valid
|
// Router creation will fail due to health checks, but config should be valid
|
||||||
|
|||||||
Reference in New Issue
Block a user