[router] migrate router from actix to axum (#8479)
This commit is contained in:
@@ -1,10 +1,17 @@
|
||||
//! Router implementations
|
||||
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use reqwest::Client;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
|
||||
pub mod factory;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
@@ -33,54 +40,55 @@ pub trait WorkerManagement: Send + Sync {
|
||||
///
|
||||
/// This trait provides a unified interface for routing requests,
|
||||
/// regardless of whether it's a regular router or PD router.
|
||||
#[async_trait(?Send)]
|
||||
#[async_trait]
|
||||
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
/// Get a reference to self as Any for downcasting
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
|
||||
/// Route a health check request
|
||||
async fn health(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
async fn health(&self, client: &Client, req: Request<Body>) -> Response;
|
||||
|
||||
/// Route a health generate request
|
||||
async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response;
|
||||
|
||||
/// Get server information
|
||||
async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response;
|
||||
|
||||
/// Get available models
|
||||
async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response;
|
||||
|
||||
/// Get model information
|
||||
async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response;
|
||||
|
||||
/// Route a generate request
|
||||
async fn route_generate(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse;
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
) -> Response;
|
||||
|
||||
/// Route a chat completion request
|
||||
async fn route_chat(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse;
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
) -> Response;
|
||||
|
||||
/// Route a completion request
|
||||
async fn route_completion(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse;
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
) -> Response;
|
||||
|
||||
/// Flush cache on all workers
|
||||
async fn flush_cache(&self, client: &Client) -> HttpResponse;
|
||||
async fn flush_cache(&self, client: &Client) -> Response;
|
||||
|
||||
/// Get worker loads (for monitoring)
|
||||
async fn get_worker_loads(&self, client: &Client) -> HttpResponse;
|
||||
async fn get_worker_loads(&self, client: &Client) -> Response;
|
||||
|
||||
/// Get router type name
|
||||
fn router_type(&self) -> &'static str;
|
||||
@@ -91,11 +99,11 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
}
|
||||
|
||||
/// Server liveness check - is the server process running
|
||||
fn liveness(&self) -> HttpResponse {
|
||||
fn liveness(&self) -> Response {
|
||||
// Simple liveness check - if we can respond, we're alive
|
||||
HttpResponse::Ok().body("OK")
|
||||
(StatusCode::OK, "OK").into_response()
|
||||
}
|
||||
|
||||
/// Server readiness check - is the server ready to handle requests
|
||||
fn readiness(&self) -> HttpResponse;
|
||||
fn readiness(&self) -> Response;
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,23 @@
|
||||
use crate::core::{HealthChecker, Worker, WorkerFactory};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::middleware::get_request_id;
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
|
||||
req.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
@@ -239,154 +245,107 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
worker_url: &str,
|
||||
route: &str,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
let request_id = get_request_id(req);
|
||||
let start = Instant::now();
|
||||
|
||||
let worker_url = if self.dp_aware {
|
||||
pub async fn send_health_check(&self, client: &Client, worker_url: &str) -> Response {
|
||||
let health_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
match Self::extract_dp_rank(worker_url) {
|
||||
Ok((worker_url_prefix, _dp_rank)) => worker_url_prefix,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return HttpResponse::InternalServerError().finish();
|
||||
error!("Failed to extract dp_rank for health check: {}", e);
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to extract dp_rank: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
}
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
let mut request_builder = client.get(format!("{}{}", worker_url, route));
|
||||
|
||||
// Copy all headers from original request except for /health because it does not need authorization
|
||||
if route != "/health" {
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
// Skip Content-Type and Content-Length as .json() sets them
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
let request_builder = client.get(format!("{}/health", health_url));
|
||||
|
||||
let response = match request_builder.send().await {
|
||||
Ok(res) => {
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Ok(body) => (status, body).into_response(),
|
||||
Err(e) => {
|
||||
error!(
|
||||
request_id = %request_id,
|
||||
worker_url = %worker_url,
|
||||
route = %route,
|
||||
worker_url = %health_url,
|
||||
error = %e,
|
||||
"Failed to read response body"
|
||||
"Failed to read health response body"
|
||||
);
|
||||
HttpResponse::InternalServerError()
|
||||
.body(format!("Failed to read response body: {}", e))
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response body: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
request_id = %request_id,
|
||||
worker_url = %worker_url,
|
||||
route = %route,
|
||||
worker_url = %health_url,
|
||||
error = %e,
|
||||
"Failed to send request to worker"
|
||||
"Failed to send health request to worker"
|
||||
);
|
||||
HttpResponse::InternalServerError().body(format!(
|
||||
"Failed to send request to worker {}: {}",
|
||||
worker_url, e
|
||||
))
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to send request to worker {}: {}", health_url, e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
};
|
||||
|
||||
// Record request metrics
|
||||
if route != "/health" {
|
||||
let duration = start.elapsed();
|
||||
RouterMetrics::record_request(route);
|
||||
RouterMetrics::record_request_duration(route, duration);
|
||||
|
||||
if !response.status().is_success() {
|
||||
RouterMetrics::record_request_error(route, "request_failed");
|
||||
}
|
||||
}
|
||||
// Don't record metrics for health checks
|
||||
response
|
||||
}
|
||||
|
||||
pub async fn route_to_first(
|
||||
// Helper method to proxy GET requests to the first available worker
|
||||
async fn proxy_get_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
route: &str,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
let request_id = get_request_id(req);
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||
let mut total_retries = 0;
|
||||
client: &Client,
|
||||
req: Request<Body>,
|
||||
endpoint: &str,
|
||||
) -> Response {
|
||||
let headers = copy_request_headers(&req);
|
||||
|
||||
while total_retries < MAX_TOTAL_RETRIES {
|
||||
match self.select_first_worker() {
|
||||
Ok(worker_url) => {
|
||||
let mut request_retries = 0;
|
||||
|
||||
// Try the same worker multiple times
|
||||
while request_retries < MAX_REQUEST_RETRIES {
|
||||
if total_retries >= 1 {
|
||||
info!("Retrying request after {} failed attempts", total_retries);
|
||||
}
|
||||
|
||||
let response = self.send_request(client, &worker_url, route, req).await;
|
||||
|
||||
if response.status().is_success() {
|
||||
return response;
|
||||
} else {
|
||||
// if the worker is healthy, it means the request is bad, so return the error response
|
||||
let health_response =
|
||||
self.send_request(client, &worker_url, "/health", req).await;
|
||||
if health_response.status().is_success() {
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
warn!(
|
||||
request_id = %request_id,
|
||||
route = %route,
|
||||
worker_url = %worker_url,
|
||||
attempt = request_retries + 1,
|
||||
max_attempts = MAX_REQUEST_RETRIES,
|
||||
"Request failed"
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
total_retries += 1;
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!(
|
||||
request_id = %request_id,
|
||||
worker_url = %worker_url,
|
||||
"Removing failed worker"
|
||||
);
|
||||
self.remove_failed_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
match self.select_first_worker() {
|
||||
Ok(worker_url) => {
|
||||
let mut request_builder = client.get(format!("{}/{}", worker_url, endpoint));
|
||||
for (name, value) in headers {
|
||||
if name.to_lowercase() != "content-type"
|
||||
&& name.to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
Err(e) => return HttpResponse::InternalServerError().body(e),
|
||||
}
|
||||
}
|
||||
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
match request_builder.send().await {
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
match res.bytes().await {
|
||||
Ok(body) => (status, body).into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Request failed: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
// New method to route typed requests directly
|
||||
@@ -395,11 +354,10 @@ impl Router {
|
||||
>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
headers: Option<&HeaderMap>,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
let request_id = get_request_id(req);
|
||||
) -> Response {
|
||||
// Handle retries like the original implementation
|
||||
let start = Instant::now();
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
@@ -440,7 +398,7 @@ impl Router {
|
||||
let response = self
|
||||
.send_typed_request(
|
||||
client,
|
||||
req,
|
||||
headers,
|
||||
typed_req,
|
||||
route,
|
||||
&worker_url,
|
||||
@@ -455,8 +413,7 @@ impl Router {
|
||||
return response;
|
||||
} else {
|
||||
// if the worker is healthy, it means the request is bad, so return the error response
|
||||
let health_response =
|
||||
self.send_request(client, &worker_url, "/health", req).await;
|
||||
let health_response = self.send_health_check(client, &worker_url).await;
|
||||
if health_response.status().is_success() {
|
||||
RouterMetrics::record_request_error(route, "request_failed");
|
||||
return response;
|
||||
@@ -464,9 +421,11 @@ impl Router {
|
||||
}
|
||||
|
||||
warn!(
|
||||
request_id = %request_id,
|
||||
"Generate request failed route={} worker_url={} attempt={} max_attempts={}",
|
||||
route, worker_url, request_retries + 1, MAX_REQUEST_RETRIES
|
||||
route,
|
||||
worker_url,
|
||||
request_retries + 1,
|
||||
MAX_REQUEST_RETRIES
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
@@ -474,17 +433,21 @@ impl Router {
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!(
|
||||
request_id = %request_id,
|
||||
"Removing failed worker after typed request failures worker_url={}", worker_url
|
||||
"Removing failed worker after typed request failures worker_url={}",
|
||||
worker_url
|
||||
);
|
||||
self.remove_failed_worker(&worker_url);
|
||||
self.remove_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RouterMetrics::record_request_error(route, "request_failed");
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"All retry attempts failed",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
// Helper method to select worker from text using the policy
|
||||
@@ -521,14 +484,13 @@ impl Router {
|
||||
async fn send_typed_request<T: serde::Serialize>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
headers: Option<&HeaderMap>,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
worker_url: &str,
|
||||
is_stream: bool,
|
||||
load_incremented: bool, // Whether load was incremented for this request
|
||||
) -> HttpResponse {
|
||||
let request_id = get_request_id(req);
|
||||
) -> Response {
|
||||
let start = Instant::now();
|
||||
|
||||
let mut request_builder = if self.dp_aware {
|
||||
@@ -536,7 +498,11 @@ impl Router {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return HttpResponse::InternalServerError().finish();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to extract dp_rank: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -544,8 +510,11 @@ impl Router {
|
||||
let mut json_val = match serde_json::to_value(typed_req) {
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
return HttpResponse::BadRequest()
|
||||
.body(format!("Convert into serde_json::Value failed: {}", e));
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Convert into serde_json::Value failed: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -560,8 +529,11 @@ impl Router {
|
||||
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
|
||||
);
|
||||
} else {
|
||||
return HttpResponse::BadRequest()
|
||||
.body("Failed to insert the data_parallel_rank field into the request body");
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Failed to insert the data_parallel_rank field into the request body",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
client
|
||||
@@ -573,11 +545,15 @@ impl Router {
|
||||
.json(typed_req) // Use json() directly with typed request
|
||||
};
|
||||
|
||||
// Copy all headers from original request
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
// Skip Content-Type and Content-Length as .json() sets them
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
|
||||
request_builder = request_builder.header(&name, &value);
|
||||
// Copy all headers from original request if provided
|
||||
if let Some(headers) = headers {
|
||||
for (name, value) in headers {
|
||||
// Skip Content-Type and Content-Length as .json() sets them
|
||||
if name.to_string().to_lowercase() != "content-type"
|
||||
&& name.to_string().to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -585,7 +561,6 @@ impl Router {
|
||||
Ok(res) => res,
|
||||
Err(e) => {
|
||||
error!(
|
||||
request_id = %request_id,
|
||||
"Failed to send typed request worker_url={} route={} error={}",
|
||||
worker_url, route, e
|
||||
);
|
||||
@@ -600,20 +575,24 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Request failed: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
if !is_stream {
|
||||
// For non-streaming requests, get response first
|
||||
let response = match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Ok(body) => (status, body).into_response(),
|
||||
Err(e) => {
|
||||
let error_msg = format!("Failed to get response body: {}", e);
|
||||
HttpResponse::InternalServerError().body(error_msg)
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
|
||||
}
|
||||
};
|
||||
|
||||
@@ -638,42 +617,86 @@ impl Router {
|
||||
let workers = Arc::clone(&self.workers);
|
||||
let worker_url = worker_url.to_string();
|
||||
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming(
|
||||
res.bytes_stream()
|
||||
.map_err(|_| {
|
||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
||||
})
|
||||
.inspect(move |bytes| {
|
||||
if let Ok(bytes) = bytes {
|
||||
if bytes
|
||||
.as_ref()
|
||||
.windows(12)
|
||||
.any(|window| window == b"data: [DONE]")
|
||||
{
|
||||
if let Ok(workers_guard) = workers.read() {
|
||||
if let Some(worker) =
|
||||
workers_guard.iter().find(|w| w.url() == &worker_url)
|
||||
{
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(
|
||||
&worker_url,
|
||||
worker.load(),
|
||||
);
|
||||
}
|
||||
let stream = res.bytes_stream();
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
// Spawn task to forward stream and detect completion
|
||||
tokio::spawn(async move {
|
||||
let mut stream = stream;
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
// Check for stream end marker
|
||||
if bytes
|
||||
.as_ref()
|
||||
.windows(12)
|
||||
.any(|window| window == b"data: [DONE]")
|
||||
{
|
||||
if let Ok(workers_guard) = workers.read() {
|
||||
if let Some(worker) =
|
||||
workers_guard.iter().find(|w| w.url() == &worker_url)
|
||||
{
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(
|
||||
&worker_url,
|
||||
worker.load(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
if tx.send(Ok(bytes)).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(format!("Stream error: {}", e)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = UnboundedReceiverStream::new(rx);
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
let mut response = Response::new(body);
|
||||
*response.status_mut() = status;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
response
|
||||
} else {
|
||||
// For requests without load tracking, just stream
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming(res.bytes_stream().map_err(|_| {
|
||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
||||
}))
|
||||
let stream = res.bytes_stream();
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
|
||||
// Spawn task to forward stream
|
||||
tokio::spawn(async move {
|
||||
let mut stream = stream;
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
if tx.send(Ok(bytes)).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(format!("Stream error: {}", e)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = UnboundedReceiverStream::new(rx);
|
||||
let body = Body::from_stream(stream);
|
||||
|
||||
let mut response = Response::new(body);
|
||||
*response.status_mut() = status;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
@@ -775,7 +798,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all the worker(s) that match the URL prefix
|
||||
pub fn remove_worker(&self, worker_url: &str) {
|
||||
if self.dp_aware {
|
||||
// remove dp-aware workers in a prefix-matching fashion
|
||||
@@ -844,28 +866,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a specific failed worker; for internal usage
|
||||
fn remove_failed_worker(&self, worker_url: &str) {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed failed worker: {}", worker_url);
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
return;
|
||||
}
|
||||
|
||||
// If cache aware policy, remove the worker from the tree
|
||||
if let Some(cache_aware) = self
|
||||
.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker(worker_url);
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
@@ -1004,7 +1004,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
|
||||
@@ -1023,100 +1022,78 @@ impl WorkerManagement for Router {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait(?Send)]
|
||||
#[async_trait]
|
||||
impl RouterTrait for Router {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _client: &Client, _req: &HttpRequest) -> HttpResponse {
|
||||
// Check local health state of all workers (consistent with PD router)
|
||||
// Note: This uses cached health status from background health checks, not live checks
|
||||
let mut all_healthy = true;
|
||||
let mut unhealthy_servers = Vec::new();
|
||||
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response {
|
||||
let workers = self.workers.read().unwrap();
|
||||
let unhealthy_servers: Vec<_> = workers
|
||||
.iter()
|
||||
.filter(|w| !w.is_healthy())
|
||||
.map(|w| w.url().to_string())
|
||||
.collect();
|
||||
|
||||
for worker in self.workers.read().unwrap().iter() {
|
||||
if !worker.is_healthy() {
|
||||
all_healthy = false;
|
||||
unhealthy_servers.push(worker.url().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if all_healthy {
|
||||
HttpResponse::Ok().body("All servers healthy")
|
||||
if unhealthy_servers.is_empty() {
|
||||
(StatusCode::OK, "All servers healthy").into_response()
|
||||
} else {
|
||||
HttpResponse::ServiceUnavailable()
|
||||
.body(format!("Unhealthy servers: {:?}", unhealthy_servers))
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
format!("Unhealthy servers: {:?}", unhealthy_servers),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
||||
// Test model generation capability by sending to first available worker
|
||||
// Note: This endpoint actually causes the model to generate a token, so we only test one worker
|
||||
self.route_to_first(client, "/health_generate", req).await
|
||||
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(client, req, "health_generate").await
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
||||
self.route_to_first(client, "/get_server_info", req).await
|
||||
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(client, req, "get_server_info").await
|
||||
}
|
||||
|
||||
async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
||||
self.route_to_first(client, "/v1/models", req).await
|
||||
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(client, req, "v1/models").await
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
||||
self.route_to_first(client, "/get_model_info", req).await
|
||||
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(client, req, "get_model_info").await
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse {
|
||||
// Convert JSON to typed request
|
||||
match serde_json::from_value::<crate::openai_api_types::GenerateRequest>(body) {
|
||||
Ok(typed_req) => {
|
||||
self.route_typed_request(client, req, &typed_req, "/generate")
|
||||
.await
|
||||
}
|
||||
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)),
|
||||
}
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
) -> Response {
|
||||
self.route_typed_request(client, headers, body, "/generate")
|
||||
.await
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse {
|
||||
// Convert JSON to typed request
|
||||
match serde_json::from_value::<crate::openai_api_types::ChatCompletionRequest>(body) {
|
||||
Ok(typed_req) => {
|
||||
self.route_typed_request(client, req, &typed_req, "/v1/chat/completions")
|
||||
.await
|
||||
}
|
||||
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)),
|
||||
}
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
) -> Response {
|
||||
self.route_typed_request(client, headers, body, "/v1/chat/completions")
|
||||
.await
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse {
|
||||
// Convert JSON to typed request
|
||||
match serde_json::from_value::<crate::openai_api_types::CompletionRequest>(body) {
|
||||
Ok(typed_req) => {
|
||||
self.route_typed_request(client, req, &typed_req, "/v1/completions")
|
||||
.await
|
||||
}
|
||||
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)),
|
||||
}
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
) -> Response {
|
||||
self.route_typed_request(client, headers, body, "/v1/completions")
|
||||
.await
|
||||
}
|
||||
|
||||
async fn flush_cache(&self, client: &Client) -> HttpResponse {
|
||||
async fn flush_cache(&self, client: &Client) -> Response {
|
||||
// Get all worker URLs
|
||||
let worker_urls = self.get_worker_urls();
|
||||
|
||||
@@ -1129,7 +1106,11 @@ impl RouterTrait for Router {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return HttpResponse::InternalServerError().finish();
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to extract dp_rank: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
@@ -1151,13 +1132,17 @@ impl RouterTrait for Router {
|
||||
});
|
||||
|
||||
if all_success {
|
||||
HttpResponse::Ok().body("Cache flushed on all servers")
|
||||
(StatusCode::OK, "Cache flushed on all servers").into_response()
|
||||
} else {
|
||||
HttpResponse::InternalServerError().body("Cache flush failed on one or more servers")
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Cache flush failed on one or more servers",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self, client: &Client) -> HttpResponse {
|
||||
async fn get_worker_loads(&self, client: &Client) -> Response {
|
||||
let urls = self.get_worker_urls();
|
||||
let mut loads = Vec::new();
|
||||
|
||||
@@ -1170,16 +1155,17 @@ impl RouterTrait for Router {
|
||||
}));
|
||||
}
|
||||
|
||||
HttpResponse::Ok().json(serde_json::json!({
|
||||
Json(serde_json::json!({
|
||||
"workers": loads
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"regular"
|
||||
}
|
||||
|
||||
fn readiness(&self) -> HttpResponse {
|
||||
fn readiness(&self) -> Response {
|
||||
// Regular router is ready if it has at least one healthy worker
|
||||
let healthy_count = self
|
||||
.workers
|
||||
@@ -1190,17 +1176,22 @@ impl RouterTrait for Router {
|
||||
.count();
|
||||
|
||||
if healthy_count > 0 {
|
||||
HttpResponse::Ok().json(serde_json::json!({
|
||||
Json(serde_json::json!({
|
||||
"status": "ready",
|
||||
"healthy_workers": healthy_count,
|
||||
"total_workers": self.workers.read().unwrap().len()
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
HttpResponse::ServiceUnavailable().json(serde_json::json!({
|
||||
"status": "not_ready",
|
||||
"reason": "no healthy workers available",
|
||||
"total_workers": self.workers.read().unwrap().len()
|
||||
}))
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
Json(serde_json::json!({
|
||||
"status": "not_ready",
|
||||
"reason": "no healthy workers available",
|
||||
"total_workers": self.workers.read().unwrap().len()
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user