[router] consolidate health endpoints and flush cache (#10876)
This commit is contained in:
@@ -12,7 +12,7 @@ use crate::core::{
|
|||||||
Worker, WorkerFactory, WorkerRegistry, WorkerType,
|
Worker, WorkerFactory, WorkerRegistry, WorkerType,
|
||||||
};
|
};
|
||||||
use crate::policies::PolicyRegistry;
|
use crate::policies::PolicyRegistry;
|
||||||
use crate::protocols::worker_spec::WorkerConfigRequest;
|
use crate::protocols::worker_spec::{FlushCacheResult, WorkerConfigRequest};
|
||||||
use crate::server::AppContext;
|
use crate::server::AppContext;
|
||||||
use futures::future;
|
use futures::future;
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
@@ -981,6 +981,104 @@ impl WorkerManager {
|
|||||||
success_threshold: config.success_threshold,
|
success_threshold: config.success_threshold,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/// Flush cache on all workers
|
||||||
|
///
|
||||||
|
/// Sends a POST request to /flush_cache endpoint on all HTTP workers.
|
||||||
|
/// Returns detailed results showing which workers succeeded and which failed.
|
||||||
|
pub async fn flush_cache_all(
|
||||||
|
worker_registry: &WorkerRegistry,
|
||||||
|
client: &reqwest::Client,
|
||||||
|
) -> Result<FlushCacheResult, String> {
|
||||||
|
warn!("Flushing cache for ALL workers - this may impact performance temporarily");
|
||||||
|
|
||||||
|
let workers = worker_registry.get_all();
|
||||||
|
|
||||||
|
let http_workers: Vec<_> = workers
|
||||||
|
.iter()
|
||||||
|
.filter(|w| matches!(w.connection_mode(), ConnectionMode::Http))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if http_workers.is_empty() {
|
||||||
|
return Ok(FlushCacheResult {
|
||||||
|
successful: vec![],
|
||||||
|
failed: vec![],
|
||||||
|
total_workers: workers.len(),
|
||||||
|
http_workers: 0,
|
||||||
|
message: "No HTTP workers available for cache flush".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Flushing cache on {} HTTP workers (out of {} total workers)",
|
||||||
|
http_workers.len(),
|
||||||
|
workers.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut tasks = Vec::new();
|
||||||
|
for worker in &http_workers {
|
||||||
|
let url = worker.url().to_string();
|
||||||
|
let flush_url = format!("{}/flush_cache", url);
|
||||||
|
let mut request = client.post(&flush_url);
|
||||||
|
|
||||||
|
if let Some(api_key) = worker.api_key() {
|
||||||
|
request = request.header("Authorization", format!("Bearer {}", api_key));
|
||||||
|
}
|
||||||
|
|
||||||
|
let worker_url = url.clone();
|
||||||
|
tasks.push(async move {
|
||||||
|
let result = request.send().await;
|
||||||
|
(worker_url, result)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let results = futures::future::join_all(tasks).await;
|
||||||
|
|
||||||
|
let mut successful = Vec::new();
|
||||||
|
let mut failed = Vec::new();
|
||||||
|
|
||||||
|
for (url, result) in results {
|
||||||
|
match result {
|
||||||
|
Ok(response) if response.status().is_success() => {
|
||||||
|
debug!("Successfully flushed cache on worker: {}", url);
|
||||||
|
successful.push(url);
|
||||||
|
}
|
||||||
|
Ok(response) => {
|
||||||
|
let error = format!("HTTP {}", response.status());
|
||||||
|
warn!("Failed to flush cache on worker {}: {}", url, error);
|
||||||
|
failed.push((url, error));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let error = e.to_string();
|
||||||
|
error!("Failed to connect to worker {}: {}", url, error);
|
||||||
|
failed.push((url, error));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let message = if failed.is_empty() {
|
||||||
|
format!(
|
||||||
|
"Successfully flushed cache on all {} HTTP workers",
|
||||||
|
successful.len()
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
format!(
|
||||||
|
"Cache flush completed: {} succeeded, {} failed (out of {} HTTP workers)",
|
||||||
|
successful.len(),
|
||||||
|
failed.len(),
|
||||||
|
http_workers.len()
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
info!("{}", message);
|
||||||
|
|
||||||
|
Ok(FlushCacheResult {
|
||||||
|
successful,
|
||||||
|
failed,
|
||||||
|
total_workers: workers.len(),
|
||||||
|
http_workers: http_workers.len(),
|
||||||
|
message,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -200,3 +200,18 @@ pub struct ServerInfo {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub chat_template: Option<String>,
|
pub chat_template: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Result from flush cache operations across workers
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct FlushCacheResult {
|
||||||
|
/// URLs of workers where cache flush succeeded
|
||||||
|
pub successful: Vec<String>,
|
||||||
|
/// URLs and error messages for workers where cache flush failed
|
||||||
|
pub failed: Vec<(String, String)>,
|
||||||
|
/// Total number of workers attempted
|
||||||
|
pub total_workers: usize,
|
||||||
|
/// Number of HTTP workers (gRPC workers don't support flush cache)
|
||||||
|
pub http_workers: usize,
|
||||||
|
/// Human-readable summary message
|
||||||
|
pub message: String,
|
||||||
|
}
|
||||||
|
|||||||
@@ -252,12 +252,13 @@ impl RouterTrait for GrpcPDRouter {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, _req: Request<Body>) -> Response {
|
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
// TODO: Implement actual generation test for gRPC PD mode
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Health generate not yet implemented for gRPC PD",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||||
@@ -339,10 +340,6 @@ impl RouterTrait for GrpcPDRouter {
|
|||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_worker_loads(&self) -> Response {
|
async fn get_worker_loads(&self) -> Response {
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
}
|
}
|
||||||
@@ -350,8 +347,4 @@ impl RouterTrait for GrpcPDRouter {
|
|||||||
fn router_type(&self) -> &'static str {
|
fn router_type(&self) -> &'static str {
|
||||||
"grpc_pd"
|
"grpc_pd"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readiness(&self) -> Response {
|
|
||||||
(StatusCode::SERVICE_UNAVAILABLE).into_response()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -699,12 +699,13 @@ impl RouterTrait for GrpcRouter {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, _req: Request<Body>) -> Response {
|
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
// TODO: Implement actual generation test for gRPC
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Health generate not yet implemented for gRPC",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||||
@@ -786,10 +787,6 @@ impl RouterTrait for GrpcRouter {
|
|||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_worker_loads(&self) -> Response {
|
async fn get_worker_loads(&self) -> Response {
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
}
|
}
|
||||||
@@ -797,10 +794,6 @@ impl RouterTrait for GrpcRouter {
|
|||||||
fn router_type(&self) -> &'static str {
|
fn router_type(&self) -> &'static str {
|
||||||
"grpc"
|
"grpc"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readiness(&self) -> Response {
|
|
||||||
(StatusCode::SERVICE_UNAVAILABLE).into_response()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -20,13 +20,7 @@ use axum::{
|
|||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use serde_json::{json, to_value, Value};
|
use serde_json::{json, to_value, Value};
|
||||||
use std::{
|
use std::{any::Any, borrow::Cow, collections::HashMap, io, sync::atomic::AtomicBool};
|
||||||
any::Any,
|
|
||||||
borrow::Cow,
|
|
||||||
collections::HashMap,
|
|
||||||
io,
|
|
||||||
sync::atomic::{AtomicBool, Ordering},
|
|
||||||
};
|
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::mpsc;
|
||||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
@@ -777,7 +771,7 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, _req: Request<Body>) -> Response {
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
// Simple upstream probe: GET {base}/v1/models without auth
|
// Simple upstream probe: GET {base}/v1/models without auth
|
||||||
let url = format!("{}/v1/models", self.base_url);
|
let url = format!("{}/v1/models", self.base_url);
|
||||||
match self
|
match self
|
||||||
@@ -808,11 +802,6 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
|
||||||
// For OpenAI, health_generate is the same as health
|
|
||||||
self.health(_req).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||||
let info = json!({
|
let info = json!({
|
||||||
"router_type": "openai",
|
"router_type": "openai",
|
||||||
@@ -1307,14 +1296,6 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
|
||||||
(
|
|
||||||
StatusCode::FORBIDDEN,
|
|
||||||
"flush_cache not supported for OpenAI router",
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_worker_loads(&self) -> Response {
|
async fn get_worker_loads(&self) -> Response {
|
||||||
(
|
(
|
||||||
StatusCode::FORBIDDEN,
|
StatusCode::FORBIDDEN,
|
||||||
@@ -1327,14 +1308,6 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
"openai"
|
"openai"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readiness(&self) -> Response {
|
|
||||||
if self.healthy.load(Ordering::Acquire) && self.circuit_breaker.can_execute() {
|
|
||||||
(StatusCode::OK, "Ready").into_response()
|
|
||||||
} else {
|
|
||||||
(StatusCode::SERVICE_UNAVAILABLE, "Not ready").into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn route_embeddings(
|
async fn route_embeddings(
|
||||||
&self,
|
&self,
|
||||||
_headers: Option<&HeaderMap>,
|
_headers: Option<&HeaderMap>,
|
||||||
|
|||||||
@@ -53,41 +53,6 @@ struct PDRequestContext<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PDRouter {
|
impl PDRouter {
|
||||||
async fn process_workers(
|
|
||||||
&self,
|
|
||||||
worker_type_enum: WorkerType,
|
|
||||||
worker_type: &str,
|
|
||||||
endpoint: &str,
|
|
||||||
) -> (Vec<String>, Vec<String>) {
|
|
||||||
let mut results = Vec::new();
|
|
||||||
let mut errors = Vec::new();
|
|
||||||
|
|
||||||
let workers = self.worker_registry.get_by_type(&worker_type_enum);
|
|
||||||
let urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
|
|
||||||
|
|
||||||
for worker_url in urls {
|
|
||||||
let url = format!("{}/{}", worker_url, endpoint);
|
|
||||||
match self.client.post(&url).send().await {
|
|
||||||
Ok(res) if res.status().is_success() => {
|
|
||||||
results.push(format!("{} {}: OK", worker_type, worker_url));
|
|
||||||
}
|
|
||||||
Ok(res) => {
|
|
||||||
errors.push(format!(
|
|
||||||
"{} {} returned status: {}",
|
|
||||||
worker_type,
|
|
||||||
worker_url,
|
|
||||||
res.status()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
errors.push(format!("{} {} error: {}", worker_type, worker_url, e));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
(results, errors)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn _get_worker_url_and_key(&self, w: &Arc<dyn Worker>) -> (String, Option<String>) {
|
fn _get_worker_url_and_key(&self, w: &Arc<dyn Worker>) -> (String, Option<String>) {
|
||||||
(w.url().to_string(), w.api_key().clone())
|
(w.url().to_string(), w.api_key().clone())
|
||||||
}
|
}
|
||||||
@@ -1167,36 +1132,6 @@ impl RouterTrait for PDRouter {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, _req: Request<Body>) -> Response {
|
|
||||||
// This is a server readiness check - checking if we have healthy workers
|
|
||||||
// Workers handle their own health checks in the background
|
|
||||||
let mut all_healthy = true;
|
|
||||||
let mut unhealthy_servers = Vec::new();
|
|
||||||
|
|
||||||
// Check all workers
|
|
||||||
for worker in self.worker_registry.get_all() {
|
|
||||||
if !worker.is_healthy() {
|
|
||||||
all_healthy = false;
|
|
||||||
let worker_type = match worker.worker_type() {
|
|
||||||
WorkerType::Prefill { .. } => "Prefill",
|
|
||||||
WorkerType::Decode => "Decode",
|
|
||||||
_ => "Worker",
|
|
||||||
};
|
|
||||||
unhealthy_servers.push(format!("{}: {}", worker_type, worker.url()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if all_healthy {
|
|
||||||
(StatusCode::OK, "All servers healthy").into_response()
|
|
||||||
} else {
|
|
||||||
(
|
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
|
||||||
format!("Unhealthy servers: {:?}", unhealthy_servers),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
// Test model generation capability by selecting a random pair and testing them
|
// Test model generation capability by selecting a random pair and testing them
|
||||||
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
||||||
@@ -1483,45 +1418,6 @@ impl RouterTrait for PDRouter {
|
|||||||
self.execute_dual_dispatch(headers, body, context).await
|
self.execute_dual_dispatch(headers, body, context).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
|
||||||
// Process both prefill and decode workers
|
|
||||||
let (prefill_results, prefill_errors) = self
|
|
||||||
.process_workers(
|
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: None,
|
|
||||||
},
|
|
||||||
"Prefill",
|
|
||||||
"flush_cache",
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
let (decode_results, decode_errors) = self
|
|
||||||
.process_workers(WorkerType::Decode, "Decode", "flush_cache")
|
|
||||||
.await;
|
|
||||||
|
|
||||||
// Combine results and errors
|
|
||||||
let mut results = prefill_results;
|
|
||||||
results.extend(decode_results);
|
|
||||||
let mut errors = prefill_errors;
|
|
||||||
errors.extend(decode_errors);
|
|
||||||
|
|
||||||
if errors.is_empty() {
|
|
||||||
(
|
|
||||||
StatusCode::OK,
|
|
||||||
format!("Cache flushed successfully: {:?}", results),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
} else {
|
|
||||||
(
|
|
||||||
StatusCode::PARTIAL_CONTENT,
|
|
||||||
format!(
|
|
||||||
"Partial success. Results: {:?}, Errors: {:?}",
|
|
||||||
results, errors
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_worker_loads(&self) -> Response {
|
async fn get_worker_loads(&self) -> Response {
|
||||||
let mut loads = HashMap::new();
|
let mut loads = HashMap::new();
|
||||||
let mut errors = Vec::new();
|
let mut errors = Vec::new();
|
||||||
@@ -1563,59 +1459,6 @@ impl RouterTrait for PDRouter {
|
|||||||
fn router_type(&self) -> &'static str {
|
fn router_type(&self) -> &'static str {
|
||||||
"pd"
|
"pd"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readiness(&self) -> Response {
|
|
||||||
// PD router is ready if it has at least one healthy prefill AND one healthy decode worker
|
|
||||||
let prefill_workers = self.worker_registry.get_prefill_workers();
|
|
||||||
let decode_workers = self.worker_registry.get_decode_workers();
|
|
||||||
|
|
||||||
let healthy_prefill_count = prefill_workers.iter().filter(|w| w.is_healthy()).count();
|
|
||||||
|
|
||||||
let healthy_decode_count = decode_workers.iter().filter(|w| w.is_healthy()).count();
|
|
||||||
|
|
||||||
let total_prefill = prefill_workers.len();
|
|
||||||
let total_decode = decode_workers.len();
|
|
||||||
|
|
||||||
if healthy_prefill_count > 0 && healthy_decode_count > 0 {
|
|
||||||
Json(json!({
|
|
||||||
"status": "ready",
|
|
||||||
"prefill": {
|
|
||||||
"healthy": healthy_prefill_count,
|
|
||||||
"total": total_prefill
|
|
||||||
},
|
|
||||||
"decode": {
|
|
||||||
"healthy": healthy_decode_count,
|
|
||||||
"total": total_decode
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
.into_response()
|
|
||||||
} else {
|
|
||||||
let mut reasons = Vec::new();
|
|
||||||
if healthy_prefill_count == 0 {
|
|
||||||
reasons.push("no healthy prefill workers");
|
|
||||||
}
|
|
||||||
if healthy_decode_count == 0 {
|
|
||||||
reasons.push("no healthy decode workers");
|
|
||||||
}
|
|
||||||
|
|
||||||
(
|
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
|
||||||
Json(serde_json::json!({
|
|
||||||
"status": "not_ready",
|
|
||||||
"reason": reasons.join(", "),
|
|
||||||
"prefill": {
|
|
||||||
"healthy": healthy_prefill_count,
|
|
||||||
"total": total_prefill
|
|
||||||
},
|
|
||||||
"decode": {
|
|
||||||
"healthy": healthy_decode_count,
|
|
||||||
"total": total_decode
|
|
||||||
}
|
|
||||||
})),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -1692,37 +1535,6 @@ mod tests {
|
|||||||
assert!(result.unwrap_err().contains("No prefill workers available"));
|
assert!(result.unwrap_err().contains("No prefill workers available"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_health_endpoints() {
|
|
||||||
let router = create_test_pd_router();
|
|
||||||
|
|
||||||
let prefill_worker = create_test_worker(
|
|
||||||
"http://localhost:8000".to_string(),
|
|
||||||
WorkerType::Prefill {
|
|
||||||
bootstrap_port: None,
|
|
||||||
},
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
let decode_worker = create_test_worker(
|
|
||||||
"http://localhost:8001".to_string(),
|
|
||||||
WorkerType::Decode,
|
|
||||||
true,
|
|
||||||
);
|
|
||||||
|
|
||||||
router.worker_registry.register(Arc::from(prefill_worker));
|
|
||||||
router.worker_registry.register(Arc::from(decode_worker));
|
|
||||||
|
|
||||||
let http_req = axum::http::Request::builder()
|
|
||||||
.body(axum::body::Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
let response = router.health(http_req).await;
|
|
||||||
|
|
||||||
assert_eq!(response.status(), 200);
|
|
||||||
|
|
||||||
let response = router.readiness();
|
|
||||||
assert_eq!(response.status(), 200);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_load_monitor_updates() {
|
async fn test_load_monitor_updates() {
|
||||||
let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
|
let power_of_two_policy = Arc::new(crate::policies::PowerOfTwoPolicy::new());
|
||||||
|
|||||||
@@ -829,25 +829,6 @@ impl RouterTrait for Router {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, _req: Request<Body>) -> Response {
|
|
||||||
let workers = self.worker_registry.get_all();
|
|
||||||
let unhealthy_servers: Vec<_> = workers
|
|
||||||
.iter()
|
|
||||||
.filter(|w| !w.is_healthy())
|
|
||||||
.map(|w| w.url().to_string())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if unhealthy_servers.is_empty() {
|
|
||||||
(StatusCode::OK, "All servers healthy").into_response()
|
|
||||||
} else {
|
|
||||||
(
|
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
|
||||||
format!("Unhealthy servers: {:?}", unhealthy_servers),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health_generate(&self, req: Request<Body>) -> Response {
|
async fn health_generate(&self, req: Request<Body>) -> Response {
|
||||||
self.proxy_get_request(req, "health_generate").await
|
self.proxy_get_request(req, "health_generate").await
|
||||||
}
|
}
|
||||||
@@ -972,68 +953,6 @@ impl RouterTrait for Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
|
||||||
// Get all workers
|
|
||||||
let workers = self.worker_registry.get_all();
|
|
||||||
let worker_urls: Vec<String> = workers.iter().map(|w| w.url().to_string()).collect();
|
|
||||||
|
|
||||||
// Send requests to all workers concurrently without headers
|
|
||||||
let mut tasks = Vec::new();
|
|
||||||
for worker_url in &worker_urls {
|
|
||||||
// Get the worker's API key if available
|
|
||||||
let api_key = self
|
|
||||||
.worker_registry
|
|
||||||
.get_by_url(worker_url)
|
|
||||||
.and_then(|w| w.api_key().clone());
|
|
||||||
|
|
||||||
let worker_url = if self.dp_aware {
|
|
||||||
// Need to extract the URL from "http://host:port@dp_rank"
|
|
||||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
|
||||||
Ok(tup) => tup,
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to extract dp_rank: {}", e);
|
|
||||||
return (
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Failed to extract dp_rank: {}", e),
|
|
||||||
)
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
worker_url_prefix
|
|
||||||
} else {
|
|
||||||
worker_url
|
|
||||||
};
|
|
||||||
let mut request_builder = self.client.post(format!("{}/flush_cache", worker_url));
|
|
||||||
|
|
||||||
if let Some(key) = api_key {
|
|
||||||
request_builder =
|
|
||||||
request_builder.header("Authorization", format!("Bearer {}", key));
|
|
||||||
}
|
|
||||||
|
|
||||||
tasks.push(request_builder.send());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for all responses
|
|
||||||
let results = futures_util::future::join_all(tasks).await;
|
|
||||||
|
|
||||||
// Check if all succeeded
|
|
||||||
let all_success = results.iter().all(|r| {
|
|
||||||
r.as_ref()
|
|
||||||
.map(|res| res.status().is_success())
|
|
||||||
.unwrap_or(false)
|
|
||||||
});
|
|
||||||
|
|
||||||
if all_success {
|
|
||||||
(StatusCode::OK, "Cache flushed on all servers").into_response()
|
|
||||||
} else {
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"Cache flush failed on one or more servers",
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_worker_loads(&self) -> Response {
|
async fn get_worker_loads(&self) -> Response {
|
||||||
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
|
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
|
||||||
let mut loads = Vec::new();
|
let mut loads = Vec::new();
|
||||||
@@ -1056,32 +975,6 @@ impl RouterTrait for Router {
|
|||||||
fn router_type(&self) -> &'static str {
|
fn router_type(&self) -> &'static str {
|
||||||
"regular"
|
"regular"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readiness(&self) -> Response {
|
|
||||||
// Regular router is ready if it has at least one healthy worker
|
|
||||||
let workers = self.worker_registry.get_all();
|
|
||||||
let healthy_count = workers.iter().filter(|w| w.is_healthy()).count();
|
|
||||||
let total_workers = workers.len();
|
|
||||||
|
|
||||||
if healthy_count > 0 {
|
|
||||||
Json(serde_json::json!({
|
|
||||||
"status": "ready",
|
|
||||||
"healthy_workers": healthy_count,
|
|
||||||
"total_workers": total_workers
|
|
||||||
}))
|
|
||||||
.into_response()
|
|
||||||
} else {
|
|
||||||
(
|
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
|
||||||
Json(serde_json::json!({
|
|
||||||
"status": "not_ready",
|
|
||||||
"reason": "no healthy workers available",
|
|
||||||
"total_workers": total_workers
|
|
||||||
})),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -34,9 +34,6 @@ pub trait RouterTrait: Send + Sync + Debug {
|
|||||||
/// Get a reference to self as Any for downcasting
|
/// Get a reference to self as Any for downcasting
|
||||||
fn as_any(&self) -> &dyn std::any::Any;
|
fn as_any(&self) -> &dyn std::any::Any;
|
||||||
|
|
||||||
/// Route a health check request
|
|
||||||
async fn health(&self, req: Request<Body>) -> Response;
|
|
||||||
|
|
||||||
/// Route a health generate request
|
/// Route a health generate request
|
||||||
async fn health_generate(&self, req: Request<Body>) -> Response;
|
async fn health_generate(&self, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
@@ -129,9 +126,6 @@ pub trait RouterTrait: Send + Sync + Debug {
|
|||||||
model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
) -> Response;
|
) -> Response;
|
||||||
|
|
||||||
/// Flush cache on all workers
|
|
||||||
async fn flush_cache(&self) -> Response;
|
|
||||||
|
|
||||||
/// Get worker loads (for monitoring)
|
/// Get worker loads (for monitoring)
|
||||||
async fn get_worker_loads(&self) -> Response;
|
async fn get_worker_loads(&self) -> Response;
|
||||||
|
|
||||||
@@ -142,13 +136,4 @@ pub trait RouterTrait: Send + Sync + Debug {
|
|||||||
fn is_pd_mode(&self) -> bool {
|
fn is_pd_mode(&self) -> bool {
|
||||||
self.router_type() == "pd"
|
self.router_type() == "pd"
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Server liveness check - is the server process running
|
|
||||||
fn liveness(&self) -> Response {
|
|
||||||
// Simple liveness check - if we can respond, we're alive
|
|
||||||
(StatusCode::OK, "OK").into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Server readiness check - is the server ready to handle requests
|
|
||||||
fn readiness(&self) -> Response;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -289,10 +289,6 @@ impl RouterTrait for RouterManager {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, _req: Request<Body>) -> Response {
|
|
||||||
(StatusCode::OK, "RouterManager is healthy").into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
// TODO: Should check if any router has healthy workers
|
// TODO: Should check if any router has healthy workers
|
||||||
(
|
(
|
||||||
@@ -512,16 +508,6 @@ impl RouterTrait for RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
|
||||||
// TODO: Call flush_cache on all routers that have workers
|
|
||||||
if self.routers.is_empty() {
|
|
||||||
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
|
||||||
} else {
|
|
||||||
// TODO: Actually flush cache on all routers
|
|
||||||
(StatusCode::OK, "Cache flush requested").into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_worker_loads(&self) -> Response {
|
async fn get_worker_loads(&self) -> Response {
|
||||||
let workers = self.worker_registry.get_all();
|
let workers = self.worker_registry.get_all();
|
||||||
let loads: Vec<serde_json::Value> = workers
|
let loads: Vec<serde_json::Value> = workers
|
||||||
@@ -549,15 +535,6 @@ impl RouterTrait for RouterManager {
|
|||||||
fn router_type(&self) -> &'static str {
|
fn router_type(&self) -> &'static str {
|
||||||
"manager"
|
"manager"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readiness(&self) -> Response {
|
|
||||||
if self.routers.is_empty() {
|
|
||||||
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
|
||||||
} else {
|
|
||||||
// TODO: Check readiness of all routers
|
|
||||||
(StatusCode::OK, "Ready").into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for RouterManager {
|
impl std::fmt::Debug for RouterManager {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
config::{ConnectionMode, HistoryBackend, RouterConfig},
|
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
||||||
core::{WorkerManager, WorkerRegistry, WorkerType},
|
core::{WorkerManager, WorkerRegistry, WorkerType},
|
||||||
data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage},
|
data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage},
|
||||||
logging::{self, LoggingConfig},
|
logging::{self, LoggingConfig},
|
||||||
@@ -121,16 +121,56 @@ async fn sink_handler() -> Response {
|
|||||||
StatusCode::NOT_FOUND.into_response()
|
StatusCode::NOT_FOUND.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
|
async fn liveness() -> Response {
|
||||||
state.router.liveness()
|
(StatusCode::OK, "OK").into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
|
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
|
||||||
state.router.readiness()
|
let workers = state.context.worker_registry.get_all();
|
||||||
|
let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();
|
||||||
|
|
||||||
|
let is_ready = if state.context.router_config.enable_igw {
|
||||||
|
!healthy_workers.is_empty()
|
||||||
|
} else {
|
||||||
|
match &state.context.router_config.mode {
|
||||||
|
RoutingMode::PrefillDecode { .. } => {
|
||||||
|
let has_prefill = healthy_workers
|
||||||
|
.iter()
|
||||||
|
.any(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }));
|
||||||
|
let has_decode = healthy_workers
|
||||||
|
.iter()
|
||||||
|
.any(|w| matches!(w.worker_type(), WorkerType::Decode));
|
||||||
|
has_prefill && has_decode
|
||||||
|
}
|
||||||
|
RoutingMode::Regular { .. } => !healthy_workers.is_empty(),
|
||||||
|
RoutingMode::OpenAI { .. } => !healthy_workers.is_empty(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_ready {
|
||||||
|
(
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(json!({
|
||||||
|
"status": "ready",
|
||||||
|
"healthy_workers": healthy_workers.len(),
|
||||||
|
"total_workers": workers.len()
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
Json(json!({
|
||||||
|
"status": "not ready",
|
||||||
|
"reason": "insufficient healthy workers"
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
async fn health(_state: State<Arc<AppState>>) -> Response {
|
||||||
state.router.health(req).await
|
liveness().await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
@@ -311,7 +351,52 @@ async fn remove_worker(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||||
state.router.flush_cache().await
|
match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(result) => {
|
||||||
|
if result.failed.is_empty() {
|
||||||
|
(
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(json!({
|
||||||
|
"status": "success",
|
||||||
|
"message": result.message,
|
||||||
|
"workers_flushed": result.successful.len(),
|
||||||
|
"total_http_workers": result.http_workers,
|
||||||
|
"total_workers": result.total_workers
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::PARTIAL_CONTENT,
|
||||||
|
Json(json!({
|
||||||
|
"status": "partial_success",
|
||||||
|
"message": result.message,
|
||||||
|
"successful": result.successful,
|
||||||
|
"failed": result.failed.into_iter().map(|(url, err)| json!({
|
||||||
|
"worker": url,
|
||||||
|
"error": err
|
||||||
|
})).collect::<Vec<_>>(),
|
||||||
|
"total_http_workers": result.http_workers,
|
||||||
|
"total_workers": result.total_workers
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to flush cache: {}", e);
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({
|
||||||
|
"status": "error",
|
||||||
|
"message": format!("Failed to flush cache: {}", e)
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||||
|
|||||||
@@ -239,13 +239,6 @@ mod health_tests {
|
|||||||
let resp = app.oneshot(req).await.unwrap();
|
let resp = app.oneshot(req).await.unwrap();
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
|
|
||||||
// The health endpoint returns plain text, not JSON
|
|
||||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let body_str = String::from_utf8_lossy(&body);
|
|
||||||
assert!(body_str.contains("All servers healthy"));
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
ctx.shutdown().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -101,27 +101,6 @@ async fn test_openai_router_creation() {
|
|||||||
assert!(!router.is_pd_mode());
|
assert!(!router.is_pd_mode());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test health endpoints
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_openai_router_health() {
|
|
||||||
let router = OpenAIRouter::new(
|
|
||||||
"https://api.openai.com".to_string(),
|
|
||||||
None,
|
|
||||||
Arc::new(MemoryResponseStorage::new()),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let req = Request::builder()
|
|
||||||
.method(Method::GET)
|
|
||||||
.uri("/health")
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let response = router.health(req).await;
|
|
||||||
assert_eq!(response.status(), StatusCode::OK);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Test server info endpoint
|
/// Test server info endpoint
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_openai_router_server_info() {
|
async fn test_openai_router_server_info() {
|
||||||
|
|||||||
Reference in New Issue
Block a user