[Router] fix: send the worker api_key with get_load requests (#10385)
This commit is contained in:
@@ -38,6 +38,7 @@ pub struct Router {
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
#[allow(dead_code)]
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
@@ -71,7 +72,6 @@ impl Router {
|
||||
};
|
||||
|
||||
// Cache-aware policies are initialized in WorkerInitializer
|
||||
|
||||
// Setup load monitoring for PowerOfTwo policy
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
@@ -82,6 +82,14 @@ impl Router {
|
||||
// Check if default policy is power_of_two for load monitoring
|
||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
let monitor_api_keys = monitor_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
ctx.worker_registry
|
||||
.get_by_url(url)
|
||||
.and_then(|w| w.api_key().clone())
|
||||
})
|
||||
.collect::<Vec<Option<String>>>();
|
||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
||||
let policy_clone = default_policy.clone();
|
||||
let client_clone = ctx.client.clone();
|
||||
@@ -89,6 +97,7 @@ impl Router {
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
Self::monitor_worker_loads(
|
||||
monitor_urls,
|
||||
monitor_api_keys,
|
||||
tx,
|
||||
monitor_interval,
|
||||
policy_clone,
|
||||
@@ -912,7 +921,11 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
pub async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.worker_startup_timeout_secs))
|
||||
@@ -938,7 +951,7 @@ impl Router {
|
||||
// Need to contact the worker to extract the dp_size,
|
||||
// and add them as multiple workers
|
||||
let url_vec = vec![String::from(worker_url)];
|
||||
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
|
||||
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, api_key)
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
|
||||
let mut worker_added: bool = false;
|
||||
for dp_url in &dp_url_vec {
|
||||
@@ -948,10 +961,18 @@ impl Router {
|
||||
}
|
||||
info!("Added worker: {}", dp_url);
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker = BasicWorkerBuilder::new(dp_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
let new_worker_builder =
|
||||
BasicWorkerBuilder::new(dp_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
|
||||
let new_worker = if let Some(api_key) = api_key {
|
||||
new_worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
new_worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
@@ -978,10 +999,16 @@ impl Router {
|
||||
info!("Added worker: {}", worker_url);
|
||||
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker = BasicWorkerBuilder::new(worker_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
let new_worker_builder =
|
||||
BasicWorkerBuilder::new(worker_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone());
|
||||
|
||||
let new_worker = if let Some(api_key) = api_key {
|
||||
new_worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
new_worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
@@ -1094,7 +1121,7 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
|
||||
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
@@ -1109,12 +1136,12 @@ impl Router {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match self
|
||||
.client
|
||||
.get(format!("{}/get_load", worker_url))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
@@ -1149,6 +1176,7 @@ impl Router {
|
||||
// Background task to monitor worker loads
|
||||
async fn monitor_worker_loads(
|
||||
worker_urls: Vec<String>,
|
||||
worker_api_keys: Vec<Option<String>>,
|
||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||
interval_secs: u64,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
@@ -1160,8 +1188,8 @@ impl Router {
|
||||
interval.tick().await;
|
||||
|
||||
let mut loads = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
if let Some(load) = Self::get_worker_load_static(&client, url).await {
|
||||
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
|
||||
if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
|
||||
loads.insert(url.clone(), load);
|
||||
}
|
||||
}
|
||||
@@ -1179,7 +1207,11 @@ impl Router {
|
||||
}
|
||||
|
||||
// Static version of get_worker_load for use in monitoring task
|
||||
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
async fn get_worker_load_static(
|
||||
client: &reqwest::Client,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Option<isize> {
|
||||
let worker_url = if worker_url.contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
@@ -1194,7 +1226,11 @@ impl Router {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match client.get(format!("{}/get_load", worker_url)).send().await {
|
||||
let mut req_builder = client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
@@ -1250,8 +1286,12 @@ use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for Router {
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
Router::add_worker(self, worker_url).await
|
||||
async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Router::add_worker(self, worker_url, api_key).await
|
||||
}
|
||||
|
||||
fn remove_worker(&self, worker_url: &str) {
|
||||
@@ -1457,12 +1497,12 @@ impl RouterTrait for Router {
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
let urls = self.get_worker_urls();
|
||||
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
|
||||
let mut loads = Vec::new();
|
||||
|
||||
// Get loads from all workers
|
||||
for url in &urls {
|
||||
let load = self.get_worker_load(url).await.unwrap_or(-1);
|
||||
for (url, api_key) in &urls_with_key {
|
||||
let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
|
||||
loads.push(serde_json::json!({
|
||||
"worker": url,
|
||||
"load": load
|
||||
@@ -1521,9 +1561,11 @@ mod tests {
|
||||
// Register test workers
|
||||
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
worker_registry.register(Arc::new(worker1));
|
||||
worker_registry.register(Arc::new(worker2));
|
||||
|
||||
Reference in New Issue
Block a user