[Router]fix: fix get_load missing api_key (#10385)

This commit is contained in:
Jimmy
2025-09-22 03:28:38 +08:00
committed by GitHub
parent 12d6cf18f0
commit 56321e9fc2
21 changed files with 378 additions and 111 deletions

View File

@@ -24,7 +24,8 @@ static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's URL
fn url(&self) -> &str;
/// Get the worker's API key
fn api_key(&self) -> &Option<String>;
/// Get the worker's type (Regular, Prefill, or Decode)
fn worker_type(&self) -> WorkerType;
@@ -323,6 +324,8 @@ pub struct WorkerMetadata {
pub labels: std::collections::HashMap<String, String>,
/// Health check configuration
pub health_config: HealthConfig,
/// API key
pub api_key: Option<String>,
}
/// Basic worker implementation
@@ -379,6 +382,10 @@ impl Worker for BasicWorker {
&self.metadata.url
}
fn api_key(&self) -> &Option<String> {
&self.metadata.api_key
}
fn worker_type(&self) -> WorkerType {
self.metadata.worker_type.clone()
}
@@ -548,6 +555,10 @@ impl Worker for DPAwareWorker {
self.base_worker.url()
}
fn api_key(&self) -> &Option<String> {
self.base_worker.api_key()
}
fn worker_type(&self) -> WorkerType {
self.base_worker.worker_type()
}
@@ -650,19 +661,21 @@ impl WorkerFactory {
dp_rank: usize,
dp_size: usize,
worker_type: WorkerType,
api_key: Option<String>,
) -> Box<dyn Worker> {
Box::new(
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
.worker_type(worker_type)
.build(),
)
let mut builder =
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size).worker_type(worker_type);
if let Some(api_key) = api_key {
builder = builder.api_key(api_key);
}
Box::new(builder.build())
}
#[allow(dead_code)]
/// Get DP size from a worker
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
if let Some(key) = api_key {
if let Some(key) = &api_key {
req_builder = req_builder.bearer_auth(key);
}
@@ -708,14 +721,18 @@ impl WorkerFactory {
}
/// Convert a list of worker URLs to worker trait objects
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
pub fn urls_to_workers(urls: Vec<String>, api_key: Option<String>) -> Vec<Box<dyn Worker>> {
urls.into_iter()
.map(|url| {
Box::new(
BasicWorkerBuilder::new(url)
.worker_type(WorkerType::Regular)
.build(),
) as Box<dyn Worker>
let worker_builder = BasicWorkerBuilder::new(url).worker_type(WorkerType::Regular);
let worker = if let Some(ref api_key) = api_key {
worker_builder.api_key(api_key.clone()).build()
} else {
worker_builder.build()
};
Box::new(worker) as Box<dyn Worker>
})
.collect()
}
@@ -961,6 +978,7 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
assert_eq!(worker.url(), "http://test:8080");
assert_eq!(worker.worker_type(), WorkerType::Regular);
@@ -998,6 +1016,7 @@ mod tests {
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.health_config(custom_config.clone())
.api_key("test_api_key")
.build();
assert_eq!(worker.metadata().health_config.timeout_secs, 15);
@@ -1011,6 +1030,7 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://worker1:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
assert_eq!(worker.url(), "http://worker1:8080");
}
@@ -1020,6 +1040,7 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let regular = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
assert_eq!(regular.worker_type(), WorkerType::Regular);
@@ -1027,6 +1048,7 @@ mod tests {
.worker_type(WorkerType::Prefill {
bootstrap_port: Some(9090),
})
.api_key("test_api_key")
.build();
assert_eq!(
prefill.worker_type(),
@@ -1037,6 +1059,7 @@ mod tests {
let decode = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Decode)
.api_key("test_api_key")
.build();
assert_eq!(decode.worker_type(), WorkerType::Decode);
}
@@ -1065,6 +1088,7 @@ mod tests {
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular)
.api_key("test_api_key")
.build();
// Initial load is 0
@@ -1350,7 +1374,7 @@ mod tests {
fn test_urls_to_workers() {
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
let workers = urls_to_workers(urls);
let workers = urls_to_workers(urls, Some("test_api_key".to_string()));
assert_eq!(workers.len(), 2);
assert_eq!(workers[0].url(), "http://w1:8080");
assert_eq!(workers[1].url(), "http://w2:8080");
@@ -1547,6 +1571,7 @@ mod tests {
1,
4,
WorkerType::Regular,
Some("test_api_key".to_string()),
);
assert_eq!(worker.url(), "http://worker1:8080@1");
@@ -1565,6 +1590,7 @@ mod tests {
WorkerType::Prefill {
bootstrap_port: Some(8090),
},
Some("test_api_key".to_string()),
);
assert_eq!(worker.url(), "http://worker1:8080@0");
@@ -1680,8 +1706,13 @@ mod tests {
.worker_type(WorkerType::Decode)
.build(),
);
let dp_aware_regular =
WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular);
let dp_aware_regular = WorkerFactory::create_dp_aware(
"http://dp:8080".to_string(),
0,
2,
WorkerType::Regular,
Some("test_api_key".to_string()),
);
let dp_aware_prefill = WorkerFactory::create_dp_aware(
"http://dp-prefill:8080".to_string(),
1,
@@ -1689,12 +1720,14 @@ mod tests {
WorkerType::Prefill {
bootstrap_port: None,
},
Some("test_api_key".to_string()),
);
let dp_aware_decode = WorkerFactory::create_dp_aware(
"http://dp-decode:8080".to_string(),
0,
4,
WorkerType::Decode,
Some("test_api_key".to_string()),
);
let workers: Vec<Box<dyn Worker>> = vec![

View File

@@ -11,6 +11,7 @@ pub struct BasicWorkerBuilder {
url: String,
// Optional fields with defaults
api_key: Option<String>,
worker_type: WorkerType,
connection_mode: ConnectionMode,
labels: HashMap<String, String>,
@@ -24,6 +25,7 @@ impl BasicWorkerBuilder {
pub fn new(url: impl Into<String>) -> Self {
Self {
url: url.into(),
api_key: None,
worker_type: WorkerType::Regular,
connection_mode: ConnectionMode::Http,
labels: HashMap::new(),
@@ -37,6 +39,7 @@ impl BasicWorkerBuilder {
pub fn new_with_type(url: impl Into<String>, worker_type: WorkerType) -> Self {
Self {
url: url.into(),
api_key: None,
worker_type,
connection_mode: ConnectionMode::Http,
labels: HashMap::new(),
@@ -46,6 +49,12 @@ impl BasicWorkerBuilder {
}
}
/// Set the API key
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
/// Set the worker type (Regular, Prefill, or Decode)
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
self.worker_type = worker_type;
@@ -98,6 +107,7 @@ impl BasicWorkerBuilder {
let metadata = WorkerMetadata {
url: self.url.clone(),
api_key: self.api_key,
worker_type: self.worker_type,
connection_mode: self.connection_mode,
labels: self.labels,
@@ -121,6 +131,7 @@ impl BasicWorkerBuilder {
pub struct DPAwareWorkerBuilder {
// Required fields
base_url: String,
api_key: Option<String>,
dp_rank: usize,
dp_size: usize,
@@ -138,6 +149,7 @@ impl DPAwareWorkerBuilder {
pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self {
Self {
base_url: base_url.into(),
api_key: None,
dp_rank,
dp_size,
worker_type: WorkerType::Regular,
@@ -158,6 +170,7 @@ impl DPAwareWorkerBuilder {
) -> Self {
Self {
base_url: base_url.into(),
api_key: None,
dp_rank,
dp_size,
worker_type,
@@ -169,6 +182,12 @@ impl DPAwareWorkerBuilder {
}
}
/// Set the API key
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
/// Set the worker type (Regular, Prefill, or Decode)
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
self.worker_type = worker_type;
@@ -228,6 +247,10 @@ impl DPAwareWorkerBuilder {
if let Some(client) = self.grpc_client {
builder = builder.grpc_client(client);
}
// Add API key if provided
if let Some(api_key) = self.api_key {
builder = builder.api_key(api_key);
}
let base_worker = builder.build();
@@ -382,6 +405,7 @@ mod tests {
.connection_mode(ConnectionMode::Http)
.labels(labels.clone())
.health_config(health_config.clone())
.api_key("test_api_key")
.build();
assert_eq!(worker.url(), "http://localhost:8080@3");

View File

@@ -256,6 +256,18 @@ impl WorkerRegistry {
.collect()
}
pub fn get_all_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
self.workers
.iter()
.map(|entry| {
(
entry.value().url().to_string(),
entry.value().api_key().clone(),
)
})
.collect()
}
/// Get all model IDs with workers
pub fn get_models(&self) -> Vec<String> {
self.model_workers
@@ -442,6 +454,7 @@ mod tests {
.worker_type(WorkerType::Regular)
.labels(labels)
.circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(),
);
@@ -477,6 +490,7 @@ mod tests {
.worker_type(WorkerType::Regular)
.labels(labels1)
.circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(),
);
@@ -487,6 +501,7 @@ mod tests {
.worker_type(WorkerType::Regular)
.labels(labels2)
.circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(),
);
@@ -497,6 +512,7 @@ mod tests {
.worker_type(WorkerType::Regular)
.labels(labels3)
.circuit_breaker_config(CircuitBreakerConfig::default())
.api_key("test_api_key")
.build(),
);