[router] support /add_worker api (#2369)

This commit is contained in:
Byron Hsu
2024-12-06 01:17:04 -08:00
committed by GitHub
parent 37ee906f61
commit 67b657945a
3 changed files with 134 additions and 18 deletions

View File

@@ -7,18 +7,18 @@ use log::{debug, info};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::Duration;
#[derive(Debug)]
pub enum Router {
RoundRobin {
worker_urls: Vec<String>,
worker_urls: Arc<RwLock<Vec<String>>>,
current_index: AtomicUsize,
},
Random {
worker_urls: Vec<String>,
worker_urls: Arc<RwLock<Vec<String>>>,
},
CacheAware {
/*
@@ -81,7 +81,7 @@ pub enum Router {
Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted
during the next eviction cycle.
*/
worker_urls: Vec<String>,
worker_urls: Arc<RwLock<Vec<String>>>,
tree: Arc<Mutex<Tree>>,
running_queue: Arc<Mutex<HashMap<String, usize>>>,
processed_queue: Arc<Mutex<HashMap<String, usize>>>,
@@ -129,9 +129,11 @@ fn get_text_from_request(body: &Bytes, route: &str) -> String {
impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Self {
match policy_config {
PolicyConfig::RandomConfig => Router::Random { worker_urls },
PolicyConfig::RandomConfig => Router::Random {
worker_urls: Arc::new(RwLock::new(worker_urls)),
},
PolicyConfig::RoundRobinConfig => Router::RoundRobin {
worker_urls,
worker_urls: Arc::new(RwLock::new(worker_urls)),
current_index: std::sync::atomic::AtomicUsize::new(0),
},
PolicyConfig::CacheAwareConfig {
@@ -183,7 +185,7 @@ impl Router {
}
Router::CacheAware {
worker_urls,
worker_urls: Arc::new(RwLock::new(worker_urls)),
tree,
running_queue,
processed_queue,
@@ -201,10 +203,10 @@ impl Router {
Router::RoundRobin { worker_urls, .. }
| Router::Random { worker_urls }
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.is_empty() {
if worker_urls.read().unwrap().is_empty() {
None
} else {
Some(worker_urls[0].clone())
Some(worker_urls.read().unwrap()[0].clone())
}
}
}
@@ -228,15 +230,15 @@ impl Router {
.fetch_update(
std::sync::atomic::Ordering::SeqCst,
std::sync::atomic::Ordering::SeqCst,
|x| Some((x + 1) % worker_urls.len()),
|x| Some((x + 1) % worker_urls.read().unwrap().len()),
)
.unwrap();
worker_urls[idx].clone()
worker_urls.read().unwrap()[idx].clone()
}
Router::Random { worker_urls } => {
worker_urls[rand::random::<usize>() % worker_urls.len()].clone()
}
Router::Random { worker_urls } => worker_urls.read().unwrap()
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
.clone(),
Router::CacheAware {
worker_urls,
@@ -277,7 +279,7 @@ impl Router {
.iter()
.min_by_key(|(_url, &count)| count)
.map(|(url, _)| url.clone())
.unwrap_or_else(|| worker_urls[0].clone())
.unwrap_or_else(|| worker_urls.read().unwrap()[0].clone())
} else {
// Use cache-aware routing when load is balanced
let (matched_text, matched_worker) = tree.prefix_match(&text);
@@ -333,7 +335,10 @@ impl Router {
// For non-streaming requests, get response first
let response = match res.bytes().await {
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
Err(_) => HttpResponse::InternalServerError().finish(),
Err(e) => {
let error_msg = format!("Failed to get response body: {}", e);
HttpResponse::InternalServerError().body(error_msg)
}
};
// Then decrement running queue counter if using CacheAware
@@ -379,4 +384,16 @@ impl Router {
}))
}
}
/// Registers `worker_url` in this router's worker list.
///
/// Every routing policy keeps its workers in the same
/// `Arc<RwLock<Vec<String>>>`, so the URL is appended under the write lock
/// regardless of which variant is active.
///
/// Registration is idempotent: a URL that is already in the list is not
/// added a second time. Round-robin and random selection pick an index
/// uniformly over the list, so a duplicate entry would hand that worker a
/// double share of the traffic.
///
/// NOTE(review): only the URL list is updated here; confirm whether the
/// `CacheAware` tree and queue maps also need an entry for a new worker.
///
/// # Panics
///
/// Panics if the worker-list lock has been poisoned by a panicking writer.
pub fn add_worker(&self, worker_url: String) {
    let worker_urls = match self {
        Router::RoundRobin { worker_urls, .. }
        | Router::Random { worker_urls }
        | Router::CacheAware { worker_urls, .. } => worker_urls,
    };

    // Take the write lock once so the membership check and the push are atomic
    // with respect to concurrent `add_worker` calls.
    let mut urls = worker_urls.write().unwrap();
    if urls.contains(&worker_url) {
        info!("Worker already registered, skipping: {}", worker_url);
        return;
    }
    info!("Added worker: {}", worker_url);
    urls.push(worker_url);
}
}

View File

@@ -1,9 +1,12 @@
use crate::router::PolicyConfig;
use crate::router::Router;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use actix_web::{
delete, get, post, put, web, App, HttpRequest, HttpResponse, HttpServer, Responder,
};
use bytes::Bytes;
use env_logger::Builder;
use log::{info, LevelFilter};
use std::collections::HashMap;
use std::io::Write;
#[derive(Debug)]
@@ -128,6 +131,22 @@ async fn v1_completions(
.await
}
/// Handles `POST /add_worker?url=<worker-url>`.
///
/// Reads the `url` query parameter and registers it with the shared router.
///
/// Responses:
/// - `200 OK` with an empty body once the URL has been handed to the router.
/// - `400 Bad Request` when `url` is absent, empty, or only whitespace; a
///   blank value would otherwise be registered as a worker.
#[post("/add_worker")]
async fn add_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    // Trim first so that `?url=` and `?url=%20` are treated like a missing parameter.
    let worker_url = match query.get("url").map(|url| url.trim()) {
        Some(url) if !url.is_empty() => url.to_string(),
        _ => {
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter")
        }
    };
    data.router.add_worker(worker_url);
    HttpResponse::Ok().finish()
}
pub struct ServerConfig {
pub host: String,
pub port: u16,
@@ -183,6 +202,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service(health)
.service(health_generate)
.service(get_server_info)
.service(add_worker)
})
.bind((config.host, config.port))?
.run()