[Router]fix: fix get_load missing api_key (#10385)

This commit is contained in:
Jimmy
2025-09-22 03:28:38 +08:00
committed by GitHub
parent 12d6cf18f0
commit 56321e9fc2
21 changed files with 378 additions and 111 deletions

View File

@@ -38,6 +38,7 @@ pub struct Router {
worker_startup_timeout_secs: u64,
worker_startup_check_interval_secs: u64,
dp_aware: bool,
#[allow(dead_code)]
api_key: Option<String>,
retry_config: RetryConfig,
circuit_breaker_config: CircuitBreakerConfig,
@@ -71,7 +72,6 @@ impl Router {
};
// Cache-aware policies are initialized in WorkerInitializer
// Setup load monitoring for PowerOfTwo policy
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
@@ -82,6 +82,14 @@ impl Router {
// Check if default policy is power_of_two for load monitoring
let load_monitor_handle = if default_policy.name() == "power_of_two" {
let monitor_urls = worker_urls.clone();
let monitor_api_keys = monitor_urls
.iter()
.map(|url| {
ctx.worker_registry
.get_by_url(url)
.and_then(|w| w.api_key().clone())
})
.collect::<Vec<Option<String>>>();
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
let policy_clone = default_policy.clone();
let client_clone = ctx.client.clone();
@@ -89,6 +97,7 @@ impl Router {
Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads(
monitor_urls,
monitor_api_keys,
tx,
monitor_interval,
policy_clone,
@@ -912,7 +921,11 @@ impl Router {
}
}
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
pub async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
let start_time = std::time::Instant::now();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(self.worker_startup_timeout_secs))
@@ -938,7 +951,7 @@ impl Router {
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
let url_vec = vec![String::from(worker_url)];
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
let mut worker_added: bool = false;
for dp_url in &dp_url_vec {
@@ -948,10 +961,18 @@ impl Router {
}
info!("Added worker: {}", dp_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker = BasicWorkerBuilder::new(dp_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(self.circuit_breaker_config.clone())
.build();
let new_worker_builder =
BasicWorkerBuilder::new(dp_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(
self.circuit_breaker_config.clone(),
);
let new_worker = if let Some(api_key) = api_key {
new_worker_builder.api_key(api_key).build()
} else {
new_worker_builder.build()
};
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
@@ -978,10 +999,16 @@ impl Router {
info!("Added worker: {}", worker_url);
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
let new_worker = BasicWorkerBuilder::new(worker_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(self.circuit_breaker_config.clone())
.build();
let new_worker_builder =
BasicWorkerBuilder::new(worker_url.to_string())
.worker_type(WorkerType::Regular)
.circuit_breaker_config(self.circuit_breaker_config.clone());
let new_worker = if let Some(api_key) = api_key {
new_worker_builder.api_key(api_key).build()
} else {
new_worker_builder.build()
};
let worker_arc = Arc::new(new_worker);
self.worker_registry.register(worker_arc.clone());
@@ -1094,7 +1121,7 @@ impl Router {
}
}
async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
@@ -1109,12 +1136,12 @@ impl Router {
worker_url
};
match self
.client
.get(format!("{}/get_load", worker_url))
.send()
.await
{
let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
@@ -1149,6 +1176,7 @@ impl Router {
// Background task to monitor worker loads
async fn monitor_worker_loads(
worker_urls: Vec<String>,
worker_api_keys: Vec<Option<String>>,
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64,
policy: Arc<dyn LoadBalancingPolicy>,
@@ -1160,8 +1188,8 @@ impl Router {
interval.tick().await;
let mut loads = HashMap::new();
for url in &worker_urls {
if let Some(load) = Self::get_worker_load_static(&client, url).await {
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
loads.insert(url.clone(), load);
}
}
@@ -1179,7 +1207,11 @@ impl Router {
}
// Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
async fn get_worker_load_static(
client: &reqwest::Client,
worker_url: &str,
api_key: &Option<String>,
) -> Option<isize> {
let worker_url = if worker_url.contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
@@ -1194,7 +1226,11 @@ impl Router {
worker_url
};
match client.get(format!("{}/get_load", worker_url)).send().await {
let mut req_builder = client.get(format!("{}/get_load", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
match req_builder.send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
@@ -1250,8 +1286,12 @@ use async_trait::async_trait;
#[async_trait]
impl WorkerManagement for Router {
async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
Router::add_worker(self, worker_url).await
async fn add_worker(
&self,
worker_url: &str,
api_key: &Option<String>,
) -> Result<String, String> {
Router::add_worker(self, worker_url, api_key).await
}
fn remove_worker(&self, worker_url: &str) {
@@ -1457,12 +1497,12 @@ impl RouterTrait for Router {
}
async fn get_worker_loads(&self) -> Response {
let urls = self.get_worker_urls();
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
let mut loads = Vec::new();
// Get loads from all workers
for url in &urls {
let load = self.get_worker_load(url).await.unwrap_or(-1);
for (url, api_key) in &urls_with_key {
let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
loads.push(serde_json::json!({
"worker": url,
"load": load
@@ -1521,9 +1561,11 @@ mod tests {
// Register test workers
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
worker_registry.register(Arc::new(worker1));
worker_registry.register(Arc::new(worker2));