[router] Refactor: decouple select and send stage (#2440)
This commit is contained in:
@@ -106,28 +106,6 @@ pub enum PolicyConfig {
|
||||
},
|
||||
}
|
||||
|
||||
/// Extracts the prompt text from a request body for the given route.
///
/// * `generate` returns the string value of the `"text"` field.
/// * `v1/chat/completions` returns the `"messages"` field re-serialized as JSON.
/// * `v1/completions` returns the string value of the `"prompt"` field.
///
/// The route may be written with or without a leading `/`. An empty string is
/// returned when the body is not valid JSON, the route is not one of the
/// above, or the expected field is missing or not a string.
fn get_text_from_request(body: &Bytes, route: &str) -> String {
    // A malformed client body must not panic the router; treat it as "no text".
    let json = match serde_json::from_slice::<serde_json::Value>(body) {
        Ok(value) => value,
        Err(_) => return String::new(),
    };

    // Handlers may pass "/generate" as well as "generate"; compare on the bare form.
    let route = route.strip_prefix('/').unwrap_or(route);

    // Reads a top-level string field, defaulting to "" when absent or not a string.
    let string_field = |key: &str| -> String {
        json.get(key)
            .and_then(|value| value.as_str())
            .unwrap_or("")
            .to_string()
    };

    match route {
        "generate" => string_field("text"),
        "v1/chat/completions" => json
            .get("messages")
            .map(|messages| serde_json::to_string(messages).unwrap_or_default())
            .unwrap_or_default(),
        "v1/completions" => string_field("prompt"),
        _ => String::new(),
    }
}
|
||||
|
||||
impl Router {
|
||||
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
|
||||
// Wait until all workers are healthy
|
||||
@@ -204,20 +182,6 @@ impl Router {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_first(&self) -> Option<String> {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls }
|
||||
| Router::CacheAware { worker_urls, .. } => {
|
||||
if worker_urls.read().unwrap().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(worker_urls.read().unwrap()[0].clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wait_for_healthy_workers(
|
||||
worker_urls: &[String],
|
||||
timeout_secs: u64,
|
||||
@@ -271,14 +235,76 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn dispatch(
|
||||
fn select_first_worker(&self) -> Result<String, String> {
|
||||
match self {
|
||||
Router::RoundRobin { worker_urls, .. }
|
||||
| Router::Random { worker_urls }
|
||||
| Router::CacheAware { worker_urls, .. } => {
|
||||
if worker_urls.read().unwrap().is_empty() {
|
||||
Err("No workers are available".to_string())
|
||||
} else {
|
||||
Ok(worker_urls.read().unwrap()[0].clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
worker_url: String,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
let text = get_text_from_request(&body, route);
|
||||
match client.get(format!("{}{}", worker_url, route)).send().await {
|
||||
Ok(res) => {
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Err(e) => HttpResponse::InternalServerError()
|
||||
.body(format!("Failed to read response body: {}", e)),
|
||||
}
|
||||
}
|
||||
Err(e) => HttpResponse::InternalServerError().body(format!(
|
||||
"Failed to send request to worker {}: {}",
|
||||
worker_url, e
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
|
||||
match self.select_first_worker() {
|
||||
Ok(worker_url) => self.send_request(client, worker_url, route).await,
|
||||
Err(e) => HttpResponse::InternalServerError().body(e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts the prompt text from a request body for the given route.
///
/// * `generate` returns the string value of the `"text"` field.
/// * `v1/chat/completions` returns the `"messages"` field re-serialized as JSON.
/// * `v1/completions` returns the string value of the `"prompt"` field.
///
/// The route may be written with or without a leading `/`; the HTTP handlers
/// pass the slash-prefixed form. An empty string is returned when the body is
/// not valid JSON, the route is not one of the above, or the expected field is
/// missing or not a string.
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
    // A malformed client body must not panic the router; treat it as "no text".
    let json = match serde_json::from_slice::<serde_json::Value>(body) {
        Ok(value) => value,
        Err(_) => return String::new(),
    };

    // Compare on the bare route so "/generate" and "generate" both match.
    let route = route.strip_prefix('/').unwrap_or(route);

    // Reads a top-level string field, defaulting to "" when absent or not a string.
    let string_field = |key: &str| -> String {
        json.get(key)
            .and_then(|value| value.as_str())
            .unwrap_or("")
            .to_string()
    };

    match route {
        "generate" => string_field("text"),
        "v1/chat/completions" => json
            .get("messages")
            .map(|messages| serde_json::to_string(messages).unwrap_or_default())
            .unwrap_or_default(),
        "v1/completions" => string_field("prompt"),
        _ => String::new(),
    }
}
|
||||
|
||||
// TODO: return Result<String, String> instead of panicking
|
||||
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
|
||||
let text = self.get_text_from_request(&body, route);
|
||||
|
||||
let worker_url = match self {
|
||||
Router::RoundRobin {
|
||||
@@ -366,12 +392,23 @@ impl Router {
|
||||
}
|
||||
};
|
||||
|
||||
worker_url
|
||||
}
|
||||
|
||||
async fn send_generate_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
route: &str,
|
||||
worker_url: &str,
|
||||
) -> HttpResponse {
|
||||
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
||||
.unwrap_or(false);
|
||||
|
||||
let res = match client
|
||||
.post(format!("{}/{}", worker_url.clone(), route))
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.header(
|
||||
"Content-Type",
|
||||
req.headers()
|
||||
@@ -403,7 +440,7 @@ impl Router {
|
||||
// Then decrement running queue counter if using CacheAware
|
||||
if let Router::CacheAware { running_queue, .. } = self {
|
||||
if let Ok(mut queue) = running_queue.lock() {
|
||||
if let Some(count) = queue.get_mut(&worker_url) {
|
||||
if let Some(count) = queue.get_mut(worker_url) {
|
||||
*count = count.saturating_sub(1);
|
||||
}
|
||||
}
|
||||
@@ -412,7 +449,7 @@ impl Router {
|
||||
response
|
||||
} else if let Router::CacheAware { running_queue, .. } = self {
|
||||
let running_queue = Arc::clone(running_queue);
|
||||
let worker_url = worker_url.clone();
|
||||
let worker_url = worker_url.to_string();
|
||||
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
@@ -431,7 +468,7 @@ impl Router {
|
||||
let mut locked_queue = running_queue.lock().unwrap();
|
||||
let count = locked_queue.get_mut(&worker_url).unwrap();
|
||||
*count = count.saturating_sub(1);
|
||||
debug!("streaming is done!!")
|
||||
debug!("Streaming is done!!")
|
||||
}
|
||||
}),
|
||||
)
|
||||
@@ -444,6 +481,18 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_generate_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
let worker_url = self.select_generate_worker(&body, route);
|
||||
self.send_generate_request(client, req, body, route, &worker_url)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn add_worker(&self, worker_url: String) -> Result<String, String> {
|
||||
let interval_secs = 10; // check every 10 seconds
|
||||
let timeout_secs = 300; // 5 minutes
|
||||
|
||||
@@ -29,84 +29,41 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
async fn forward_request(
|
||||
client: &reqwest::Client,
|
||||
worker_url: String,
|
||||
route: String,
|
||||
) -> HttpResponse {
|
||||
match client.get(format!("{}{}", worker_url, route)).send().await {
|
||||
Ok(res) => {
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
// print the status
|
||||
println!(
|
||||
"Forwarding Request Worker URL: {}, Route: {}, Status: {}",
|
||||
worker_url, route, status
|
||||
);
|
||||
match res.bytes().await {
|
||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
||||
Err(_) => HttpResponse::InternalServerError().finish(),
|
||||
}
|
||||
}
|
||||
Err(_) => HttpResponse::InternalServerError().finish(),
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/health")]
|
||||
async fn health(data: web::Data<AppState>) -> impl Responder {
|
||||
let worker_url = match data.router.get_first() {
|
||||
Some(url) => url,
|
||||
None => return HttpResponse::InternalServerError().finish(),
|
||||
};
|
||||
|
||||
forward_request(&data.client, worker_url, "/health".to_string()).await
|
||||
data.router.route_to_first(&data.client, "/health").await
|
||||
}
|
||||
|
||||
#[get("/health_generate")]
|
||||
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
|
||||
let worker_url = match data.router.get_first() {
|
||||
Some(url) => url,
|
||||
None => return HttpResponse::InternalServerError().finish(),
|
||||
};
|
||||
|
||||
forward_request(&data.client, worker_url, "/health_generate".to_string()).await
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate")
|
||||
.await
|
||||
}
|
||||
|
||||
#[get("/get_server_info")]
|
||||
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
|
||||
let worker_url = match data.router.get_first() {
|
||||
Some(url) => url,
|
||||
None => return HttpResponse::InternalServerError().finish(),
|
||||
};
|
||||
|
||||
forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info")
|
||||
.await
|
||||
}
|
||||
|
||||
#[get("/v1/models")]
|
||||
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
|
||||
let worker_url = match data.router.get_first() {
|
||||
Some(url) => url,
|
||||
None => return HttpResponse::InternalServerError().finish(),
|
||||
};
|
||||
|
||||
forward_request(&data.client, worker_url, "/v1/models".to_string()).await
|
||||
data.router.route_to_first(&data.client, "/v1/models").await
|
||||
}
|
||||
|
||||
#[get("/get_model_info")]
|
||||
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
|
||||
let worker_url = match data.router.get_first() {
|
||||
Some(url) => url,
|
||||
None => return HttpResponse::InternalServerError().finish(),
|
||||
};
|
||||
|
||||
forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info")
|
||||
.await
|
||||
}
|
||||
|
||||
#[post("/generate")]
|
||||
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.dispatch(&data.client, req, body, "generate")
|
||||
.route_generate_request(&data.client, req, body, "/generate")
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -117,7 +74,7 @@ async fn v1_chat_completions(
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.dispatch(&data.client, req, body, "v1/chat/completions")
|
||||
.route_generate_request(&data.client, req, body, "/v1/chat/completions")
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -128,7 +85,7 @@ async fn v1_completions(
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.dispatch(&data.client, req, body, "v1/completions")
|
||||
.route_generate_request(&data.client, req, body, "/v1/completions")
|
||||
.await
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user