[Router] fix: get_load missing api_key (#10385)
This commit is contained in:
@@ -24,7 +24,8 @@ static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
pub trait Worker: Send + Sync + fmt::Debug {
|
||||
/// Get the worker's URL
|
||||
fn url(&self) -> &str;
|
||||
|
||||
/// Get the worker's API key as a borrowed `Option` (`None` when no key is set)
|
||||
fn api_key(&self) -> &Option<String>;
|
||||
/// Get the worker's type (Regular, Prefill, or Decode)
|
||||
fn worker_type(&self) -> WorkerType;
|
||||
|
||||
@@ -323,6 +324,8 @@ pub struct WorkerMetadata {
|
||||
pub labels: std::collections::HashMap<String, String>,
|
||||
/// Health check configuration
|
||||
pub health_config: HealthConfig,
|
||||
/// API key
|
||||
pub api_key: Option<String>,
|
||||
}
|
||||
|
||||
/// Basic worker implementation
|
||||
@@ -379,6 +382,10 @@ impl Worker for BasicWorker {
|
||||
&self.metadata.url
|
||||
}
|
||||
|
||||
/// Borrow the API key stored in this worker's metadata (`None` when no key is set).
fn api_key(&self) -> &Option<String> {
|
||||
&self.metadata.api_key
|
||||
}
|
||||
|
||||
/// Return a clone of the worker type stored in this worker's metadata.
fn worker_type(&self) -> WorkerType {
|
||||
self.metadata.worker_type.clone()
|
||||
}
|
||||
@@ -548,6 +555,10 @@ impl Worker for DPAwareWorker {
|
||||
self.base_worker.url()
|
||||
}
|
||||
|
||||
/// Forward to the wrapped base worker's `api_key()`.
fn api_key(&self) -> &Option<String> {
|
||||
self.base_worker.api_key()
|
||||
}
|
||||
|
||||
/// Forward to the wrapped base worker's `worker_type()`.
fn worker_type(&self) -> WorkerType {
|
||||
self.base_worker.worker_type()
|
||||
}
|
||||
@@ -650,19 +661,21 @@ impl WorkerFactory {
|
||||
dp_rank: usize,
|
||||
dp_size: usize,
|
||||
worker_type: WorkerType,
|
||||
api_key: Option<String>,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
|
||||
.worker_type(worker_type)
|
||||
.build(),
|
||||
)
|
||||
let mut builder =
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size).worker_type(worker_type);
|
||||
if let Some(api_key) = api_key {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
Box::new(builder.build())
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
/// Get DP size from a worker
|
||||
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
|
||||
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
|
||||
|
||||
if let Some(key) = api_key {
|
||||
if let Some(key) = &api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
@@ -708,14 +721,18 @@ impl WorkerFactory {
|
||||
}
|
||||
|
||||
/// Convert a list of worker URLs to worker trait objects
|
||||
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
|
||||
pub fn urls_to_workers(urls: Vec<String>, api_key: Option<String>) -> Vec<Box<dyn Worker>> {
|
||||
urls.into_iter()
|
||||
.map(|url| {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
) as Box<dyn Worker>
|
||||
let worker_builder = BasicWorkerBuilder::new(url).worker_type(WorkerType::Regular);
|
||||
|
||||
let worker = if let Some(ref api_key) = api_key {
|
||||
worker_builder.api_key(api_key.clone()).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -961,6 +978,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(worker.url(), "http://test:8080");
|
||||
assert_eq!(worker.worker_type(), WorkerType::Regular);
|
||||
@@ -998,6 +1016,7 @@ mod tests {
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.health_config(custom_config.clone())
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
assert_eq!(worker.metadata().health_config.timeout_secs, 15);
|
||||
@@ -1011,6 +1030,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(worker.url(), "http://worker1:8080");
|
||||
}
|
||||
@@ -1020,6 +1040,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let regular = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(regular.worker_type(), WorkerType::Regular);
|
||||
|
||||
@@ -1027,6 +1048,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
})
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(
|
||||
prefill.worker_type(),
|
||||
@@ -1037,6 +1059,7 @@ mod tests {
|
||||
|
||||
let decode = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(decode.worker_type(), WorkerType::Decode);
|
||||
}
|
||||
@@ -1065,6 +1088,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
// Initial load is 0
|
||||
@@ -1350,7 +1374,7 @@ mod tests {
|
||||
fn test_urls_to_workers() {
|
||||
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
|
||||
|
||||
let workers = urls_to_workers(urls);
|
||||
let workers = urls_to_workers(urls, Some("test_api_key".to_string()));
|
||||
assert_eq!(workers.len(), 2);
|
||||
assert_eq!(workers[0].url(), "http://w1:8080");
|
||||
assert_eq!(workers[1].url(), "http://w2:8080");
|
||||
@@ -1547,6 +1571,7 @@ mod tests {
|
||||
1,
|
||||
4,
|
||||
WorkerType::Regular,
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(worker.url(), "http://worker1:8080@1");
|
||||
@@ -1565,6 +1590,7 @@ mod tests {
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: Some(8090),
|
||||
},
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(worker.url(), "http://worker1:8080@0");
|
||||
@@ -1680,8 +1706,13 @@ mod tests {
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build(),
|
||||
);
|
||||
let dp_aware_regular =
|
||||
WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||
let dp_aware_regular = WorkerFactory::create_dp_aware(
|
||||
"http://dp:8080".to_string(),
|
||||
0,
|
||||
2,
|
||||
WorkerType::Regular,
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
let dp_aware_prefill = WorkerFactory::create_dp_aware(
|
||||
"http://dp-prefill:8080".to_string(),
|
||||
1,
|
||||
@@ -1689,12 +1720,14 @@ mod tests {
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
let dp_aware_decode = WorkerFactory::create_dp_aware(
|
||||
"http://dp-decode:8080".to_string(),
|
||||
0,
|
||||
4,
|
||||
WorkerType::Decode,
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
|
||||
@@ -11,6 +11,7 @@ pub struct BasicWorkerBuilder {
|
||||
url: String,
|
||||
|
||||
// Optional fields with defaults
|
||||
api_key: Option<String>,
|
||||
worker_type: WorkerType,
|
||||
connection_mode: ConnectionMode,
|
||||
labels: HashMap<String, String>,
|
||||
@@ -24,6 +25,7 @@ impl BasicWorkerBuilder {
|
||||
pub fn new(url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
url: url.into(),
|
||||
api_key: None,
|
||||
worker_type: WorkerType::Regular,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
labels: HashMap::new(),
|
||||
@@ -37,6 +39,7 @@ impl BasicWorkerBuilder {
|
||||
pub fn new_with_type(url: impl Into<String>, worker_type: WorkerType) -> Self {
|
||||
Self {
|
||||
url: url.into(),
|
||||
api_key: None,
|
||||
worker_type,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
labels: HashMap::new(),
|
||||
@@ -46,6 +49,12 @@ impl BasicWorkerBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the API key
///
/// The argument is converted into an owned `String` and stored as `Some(..)`,
/// replacing any previously set key. Consumes and returns the builder so calls
/// can be chained.
|
||||
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
|
||||
self.api_key = Some(api_key.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the worker type (Regular, Prefill, or Decode)
|
||||
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
|
||||
self.worker_type = worker_type;
|
||||
@@ -98,6 +107,7 @@ impl BasicWorkerBuilder {
|
||||
|
||||
let metadata = WorkerMetadata {
|
||||
url: self.url.clone(),
|
||||
api_key: self.api_key,
|
||||
worker_type: self.worker_type,
|
||||
connection_mode: self.connection_mode,
|
||||
labels: self.labels,
|
||||
@@ -121,6 +131,7 @@ impl BasicWorkerBuilder {
|
||||
pub struct DPAwareWorkerBuilder {
|
||||
// Required fields
|
||||
base_url: String,
|
||||
api_key: Option<String>,
|
||||
dp_rank: usize,
|
||||
dp_size: usize,
|
||||
|
||||
@@ -138,6 +149,7 @@ impl DPAwareWorkerBuilder {
|
||||
pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self {
|
||||
Self {
|
||||
base_url: base_url.into(),
|
||||
api_key: None,
|
||||
dp_rank,
|
||||
dp_size,
|
||||
worker_type: WorkerType::Regular,
|
||||
@@ -158,6 +170,7 @@ impl DPAwareWorkerBuilder {
|
||||
) -> Self {
|
||||
Self {
|
||||
base_url: base_url.into(),
|
||||
api_key: None,
|
||||
dp_rank,
|
||||
dp_size,
|
||||
worker_type,
|
||||
@@ -169,6 +182,12 @@ impl DPAwareWorkerBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the API key
///
/// The argument is converted into an owned `String` and stored as `Some(..)`,
/// replacing any previously set key. Consumes and returns the builder so calls
/// can be chained.
|
||||
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
|
||||
self.api_key = Some(api_key.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the worker type (Regular, Prefill, or Decode)
|
||||
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
|
||||
self.worker_type = worker_type;
|
||||
@@ -228,6 +247,10 @@ impl DPAwareWorkerBuilder {
|
||||
if let Some(client) = self.grpc_client {
|
||||
builder = builder.grpc_client(client);
|
||||
}
|
||||
// Add API key if provided
|
||||
if let Some(api_key) = self.api_key {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
let base_worker = builder.build();
|
||||
|
||||
@@ -382,6 +405,7 @@ mod tests {
|
||||
.connection_mode(ConnectionMode::Http)
|
||||
.labels(labels.clone())
|
||||
.health_config(health_config.clone())
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
assert_eq!(worker.url(), "http://localhost:8080@3");
|
||||
|
||||
@@ -256,6 +256,18 @@ impl WorkerRegistry {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get `(url, api_key)` pairs for every entry in `self.workers`
///
/// Each URL is copied into an owned `String` and each API key `Option` is cloned,
/// so the returned vector does not borrow from the registry.
pub fn get_all_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
|
||||
self.workers
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
(
|
||||
entry.value().url().to_string(),
|
||||
entry.value().api_key().clone(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all model IDs with workers
|
||||
pub fn get_models(&self) -> Vec<String> {
|
||||
self.model_workers
|
||||
@@ -442,6 +454,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
@@ -477,6 +490,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels1)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
@@ -487,6 +501,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels2)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
@@ -497,6 +512,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels3)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user