[Router]fix: fix get_load missing api_key (#10385)
This commit is contained in:
@@ -24,7 +24,8 @@ static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
pub trait Worker: Send + Sync + fmt::Debug {
|
||||
/// Get the worker's URL
|
||||
fn url(&self) -> &str;
|
||||
|
||||
/// Get the worker's API key
|
||||
fn api_key(&self) -> &Option<String>;
|
||||
/// Get the worker's type (Regular, Prefill, or Decode)
|
||||
fn worker_type(&self) -> WorkerType;
|
||||
|
||||
@@ -323,6 +324,8 @@ pub struct WorkerMetadata {
|
||||
pub labels: std::collections::HashMap<String, String>,
|
||||
/// Health check configuration
|
||||
pub health_config: HealthConfig,
|
||||
/// API key
|
||||
pub api_key: Option<String>,
|
||||
}
|
||||
|
||||
/// Basic worker implementation
|
||||
@@ -379,6 +382,10 @@ impl Worker for BasicWorker {
|
||||
&self.metadata.url
|
||||
}
|
||||
|
||||
fn api_key(&self) -> &Option<String> {
|
||||
&self.metadata.api_key
|
||||
}
|
||||
|
||||
fn worker_type(&self) -> WorkerType {
|
||||
self.metadata.worker_type.clone()
|
||||
}
|
||||
@@ -548,6 +555,10 @@ impl Worker for DPAwareWorker {
|
||||
self.base_worker.url()
|
||||
}
|
||||
|
||||
fn api_key(&self) -> &Option<String> {
|
||||
self.base_worker.api_key()
|
||||
}
|
||||
|
||||
fn worker_type(&self) -> WorkerType {
|
||||
self.base_worker.worker_type()
|
||||
}
|
||||
@@ -650,19 +661,21 @@ impl WorkerFactory {
|
||||
dp_rank: usize,
|
||||
dp_size: usize,
|
||||
worker_type: WorkerType,
|
||||
api_key: Option<String>,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
|
||||
.worker_type(worker_type)
|
||||
.build(),
|
||||
)
|
||||
let mut builder =
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size).worker_type(worker_type);
|
||||
if let Some(api_key) = api_key {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
Box::new(builder.build())
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
/// Get DP size from a worker
|
||||
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
|
||||
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
|
||||
|
||||
if let Some(key) = api_key {
|
||||
if let Some(key) = &api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
@@ -708,14 +721,18 @@ impl WorkerFactory {
|
||||
}
|
||||
|
||||
/// Convert a list of worker URLs to worker trait objects
|
||||
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
|
||||
pub fn urls_to_workers(urls: Vec<String>, api_key: Option<String>) -> Vec<Box<dyn Worker>> {
|
||||
urls.into_iter()
|
||||
.map(|url| {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
) as Box<dyn Worker>
|
||||
let worker_builder = BasicWorkerBuilder::new(url).worker_type(WorkerType::Regular);
|
||||
|
||||
let worker = if let Some(ref api_key) = api_key {
|
||||
worker_builder.api_key(api_key.clone()).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -961,6 +978,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(worker.url(), "http://test:8080");
|
||||
assert_eq!(worker.worker_type(), WorkerType::Regular);
|
||||
@@ -998,6 +1016,7 @@ mod tests {
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.health_config(custom_config.clone())
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
assert_eq!(worker.metadata().health_config.timeout_secs, 15);
|
||||
@@ -1011,6 +1030,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(worker.url(), "http://worker1:8080");
|
||||
}
|
||||
@@ -1020,6 +1040,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let regular = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(regular.worker_type(), WorkerType::Regular);
|
||||
|
||||
@@ -1027,6 +1048,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
})
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(
|
||||
prefill.worker_type(),
|
||||
@@ -1037,6 +1059,7 @@ mod tests {
|
||||
|
||||
let decode = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(decode.worker_type(), WorkerType::Decode);
|
||||
}
|
||||
@@ -1065,6 +1088,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
// Initial load is 0
|
||||
@@ -1350,7 +1374,7 @@ mod tests {
|
||||
fn test_urls_to_workers() {
|
||||
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
|
||||
|
||||
let workers = urls_to_workers(urls);
|
||||
let workers = urls_to_workers(urls, Some("test_api_key".to_string()));
|
||||
assert_eq!(workers.len(), 2);
|
||||
assert_eq!(workers[0].url(), "http://w1:8080");
|
||||
assert_eq!(workers[1].url(), "http://w2:8080");
|
||||
@@ -1547,6 +1571,7 @@ mod tests {
|
||||
1,
|
||||
4,
|
||||
WorkerType::Regular,
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(worker.url(), "http://worker1:8080@1");
|
||||
@@ -1565,6 +1590,7 @@ mod tests {
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: Some(8090),
|
||||
},
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(worker.url(), "http://worker1:8080@0");
|
||||
@@ -1680,8 +1706,13 @@ mod tests {
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build(),
|
||||
);
|
||||
let dp_aware_regular =
|
||||
WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||
let dp_aware_regular = WorkerFactory::create_dp_aware(
|
||||
"http://dp:8080".to_string(),
|
||||
0,
|
||||
2,
|
||||
WorkerType::Regular,
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
let dp_aware_prefill = WorkerFactory::create_dp_aware(
|
||||
"http://dp-prefill:8080".to_string(),
|
||||
1,
|
||||
@@ -1689,12 +1720,14 @@ mod tests {
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
let dp_aware_decode = WorkerFactory::create_dp_aware(
|
||||
"http://dp-decode:8080".to_string(),
|
||||
0,
|
||||
4,
|
||||
WorkerType::Decode,
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
|
||||
Reference in New Issue
Block a user