[Router]fix: fix get_load missing api_key (#10385)
This commit is contained in:
@@ -129,7 +129,9 @@ def test_dp_aware_worker_expansion_and_api_key(
|
||||
|
||||
# Attach worker; router should expand to dp_size logical workers
|
||||
r = requests.post(
|
||||
f"{router_url}/add_worker", params={"url": worker_url}, timeout=180
|
||||
f"{router_url}/add_worker",
|
||||
params={"url": worker_url, "api_key": api_key},
|
||||
timeout=180,
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
|
||||
@@ -139,7 +139,8 @@ def create_app(args: argparse.Namespace) -> FastAPI:
|
||||
)
|
||||
|
||||
@app.get("/get_load")
|
||||
async def get_load():
|
||||
async def get_load(request: Request):
|
||||
check_api_key(request)
|
||||
return JSONResponse({"load": _inflight})
|
||||
|
||||
def make_json_response(obj: dict, status_code: int = 200) -> JSONResponse:
|
||||
|
||||
@@ -24,7 +24,8 @@ static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
pub trait Worker: Send + Sync + fmt::Debug {
|
||||
/// Get the worker's URL
|
||||
fn url(&self) -> &str;
|
||||
|
||||
/// Get the worker's API key
|
||||
fn api_key(&self) -> &Option<String>;
|
||||
/// Get the worker's type (Regular, Prefill, or Decode)
|
||||
fn worker_type(&self) -> WorkerType;
|
||||
|
||||
@@ -323,6 +324,8 @@ pub struct WorkerMetadata {
|
||||
pub labels: std::collections::HashMap<String, String>,
|
||||
/// Health check configuration
|
||||
pub health_config: HealthConfig,
|
||||
/// API key
|
||||
pub api_key: Option<String>,
|
||||
}
|
||||
|
||||
/// Basic worker implementation
|
||||
@@ -379,6 +382,10 @@ impl Worker for BasicWorker {
|
||||
&self.metadata.url
|
||||
}
|
||||
|
||||
fn api_key(&self) -> &Option<String> {
|
||||
&self.metadata.api_key
|
||||
}
|
||||
|
||||
fn worker_type(&self) -> WorkerType {
|
||||
self.metadata.worker_type.clone()
|
||||
}
|
||||
@@ -548,6 +555,10 @@ impl Worker for DPAwareWorker {
|
||||
self.base_worker.url()
|
||||
}
|
||||
|
||||
fn api_key(&self) -> &Option<String> {
|
||||
self.base_worker.api_key()
|
||||
}
|
||||
|
||||
fn worker_type(&self) -> WorkerType {
|
||||
self.base_worker.worker_type()
|
||||
}
|
||||
@@ -650,19 +661,21 @@ impl WorkerFactory {
|
||||
dp_rank: usize,
|
||||
dp_size: usize,
|
||||
worker_type: WorkerType,
|
||||
api_key: Option<String>,
|
||||
) -> Box<dyn Worker> {
|
||||
Box::new(
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size)
|
||||
.worker_type(worker_type)
|
||||
.build(),
|
||||
)
|
||||
let mut builder =
|
||||
DPAwareWorkerBuilder::new(base_url, dp_rank, dp_size).worker_type(worker_type);
|
||||
if let Some(api_key) = api_key {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
Box::new(builder.build())
|
||||
}
|
||||
#[allow(dead_code)]
|
||||
/// Get DP size from a worker
|
||||
async fn get_worker_dp_size(url: &str, api_key: &Option<String>) -> WorkerResult<usize> {
|
||||
let mut req_builder = WORKER_CLIENT.get(format!("{}/get_server_info", url));
|
||||
|
||||
if let Some(key) = api_key {
|
||||
if let Some(key) = &api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
@@ -708,14 +721,18 @@ impl WorkerFactory {
|
||||
}
|
||||
|
||||
/// Convert a list of worker URLs to worker trait objects
|
||||
pub fn urls_to_workers(urls: Vec<String>) -> Vec<Box<dyn Worker>> {
|
||||
pub fn urls_to_workers(urls: Vec<String>, api_key: Option<String>) -> Vec<Box<dyn Worker>> {
|
||||
urls.into_iter()
|
||||
.map(|url| {
|
||||
Box::new(
|
||||
BasicWorkerBuilder::new(url)
|
||||
.worker_type(WorkerType::Regular)
|
||||
.build(),
|
||||
) as Box<dyn Worker>
|
||||
let worker_builder = BasicWorkerBuilder::new(url).worker_type(WorkerType::Regular);
|
||||
|
||||
let worker = if let Some(ref api_key) = api_key {
|
||||
worker_builder.api_key(api_key.clone()).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
@@ -961,6 +978,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(worker.url(), "http://test:8080");
|
||||
assert_eq!(worker.worker_type(), WorkerType::Regular);
|
||||
@@ -998,6 +1016,7 @@ mod tests {
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.health_config(custom_config.clone())
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
assert_eq!(worker.metadata().health_config.timeout_secs, 15);
|
||||
@@ -1011,6 +1030,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(worker.url(), "http://worker1:8080");
|
||||
}
|
||||
@@ -1020,6 +1040,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let regular = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(regular.worker_type(), WorkerType::Regular);
|
||||
|
||||
@@ -1027,6 +1048,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9090),
|
||||
})
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(
|
||||
prefill.worker_type(),
|
||||
@@ -1037,6 +1059,7 @@ mod tests {
|
||||
|
||||
let decode = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
assert_eq!(decode.worker_type(), WorkerType::Decode);
|
||||
}
|
||||
@@ -1065,6 +1088,7 @@ mod tests {
|
||||
use crate::core::BasicWorkerBuilder;
|
||||
let worker = BasicWorkerBuilder::new("http://test:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
// Initial load is 0
|
||||
@@ -1350,7 +1374,7 @@ mod tests {
|
||||
fn test_urls_to_workers() {
|
||||
let urls = vec!["http://w1:8080".to_string(), "http://w2:8080".to_string()];
|
||||
|
||||
let workers = urls_to_workers(urls);
|
||||
let workers = urls_to_workers(urls, Some("test_api_key".to_string()));
|
||||
assert_eq!(workers.len(), 2);
|
||||
assert_eq!(workers[0].url(), "http://w1:8080");
|
||||
assert_eq!(workers[1].url(), "http://w2:8080");
|
||||
@@ -1547,6 +1571,7 @@ mod tests {
|
||||
1,
|
||||
4,
|
||||
WorkerType::Regular,
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(worker.url(), "http://worker1:8080@1");
|
||||
@@ -1565,6 +1590,7 @@ mod tests {
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: Some(8090),
|
||||
},
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
|
||||
assert_eq!(worker.url(), "http://worker1:8080@0");
|
||||
@@ -1680,8 +1706,13 @@ mod tests {
|
||||
.worker_type(WorkerType::Decode)
|
||||
.build(),
|
||||
);
|
||||
let dp_aware_regular =
|
||||
WorkerFactory::create_dp_aware("http://dp:8080".to_string(), 0, 2, WorkerType::Regular);
|
||||
let dp_aware_regular = WorkerFactory::create_dp_aware(
|
||||
"http://dp:8080".to_string(),
|
||||
0,
|
||||
2,
|
||||
WorkerType::Regular,
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
let dp_aware_prefill = WorkerFactory::create_dp_aware(
|
||||
"http://dp-prefill:8080".to_string(),
|
||||
1,
|
||||
@@ -1689,12 +1720,14 @@ mod tests {
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: None,
|
||||
},
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
let dp_aware_decode = WorkerFactory::create_dp_aware(
|
||||
"http://dp-decode:8080".to_string(),
|
||||
0,
|
||||
4,
|
||||
WorkerType::Decode,
|
||||
Some("test_api_key".to_string()),
|
||||
);
|
||||
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
|
||||
@@ -11,6 +11,7 @@ pub struct BasicWorkerBuilder {
|
||||
url: String,
|
||||
|
||||
// Optional fields with defaults
|
||||
api_key: Option<String>,
|
||||
worker_type: WorkerType,
|
||||
connection_mode: ConnectionMode,
|
||||
labels: HashMap<String, String>,
|
||||
@@ -24,6 +25,7 @@ impl BasicWorkerBuilder {
|
||||
pub fn new(url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
url: url.into(),
|
||||
api_key: None,
|
||||
worker_type: WorkerType::Regular,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
labels: HashMap::new(),
|
||||
@@ -37,6 +39,7 @@ impl BasicWorkerBuilder {
|
||||
pub fn new_with_type(url: impl Into<String>, worker_type: WorkerType) -> Self {
|
||||
Self {
|
||||
url: url.into(),
|
||||
api_key: None,
|
||||
worker_type,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
labels: HashMap::new(),
|
||||
@@ -46,6 +49,12 @@ impl BasicWorkerBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the API key
|
||||
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
|
||||
self.api_key = Some(api_key.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the worker type (Regular, Prefill, or Decode)
|
||||
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
|
||||
self.worker_type = worker_type;
|
||||
@@ -98,6 +107,7 @@ impl BasicWorkerBuilder {
|
||||
|
||||
let metadata = WorkerMetadata {
|
||||
url: self.url.clone(),
|
||||
api_key: self.api_key,
|
||||
worker_type: self.worker_type,
|
||||
connection_mode: self.connection_mode,
|
||||
labels: self.labels,
|
||||
@@ -121,6 +131,7 @@ impl BasicWorkerBuilder {
|
||||
pub struct DPAwareWorkerBuilder {
|
||||
// Required fields
|
||||
base_url: String,
|
||||
api_key: Option<String>,
|
||||
dp_rank: usize,
|
||||
dp_size: usize,
|
||||
|
||||
@@ -138,6 +149,7 @@ impl DPAwareWorkerBuilder {
|
||||
pub fn new(base_url: impl Into<String>, dp_rank: usize, dp_size: usize) -> Self {
|
||||
Self {
|
||||
base_url: base_url.into(),
|
||||
api_key: None,
|
||||
dp_rank,
|
||||
dp_size,
|
||||
worker_type: WorkerType::Regular,
|
||||
@@ -158,6 +170,7 @@ impl DPAwareWorkerBuilder {
|
||||
) -> Self {
|
||||
Self {
|
||||
base_url: base_url.into(),
|
||||
api_key: None,
|
||||
dp_rank,
|
||||
dp_size,
|
||||
worker_type,
|
||||
@@ -169,6 +182,12 @@ impl DPAwareWorkerBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the API key
|
||||
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
|
||||
self.api_key = Some(api_key.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the worker type (Regular, Prefill, or Decode)
|
||||
pub fn worker_type(mut self, worker_type: WorkerType) -> Self {
|
||||
self.worker_type = worker_type;
|
||||
@@ -228,6 +247,10 @@ impl DPAwareWorkerBuilder {
|
||||
if let Some(client) = self.grpc_client {
|
||||
builder = builder.grpc_client(client);
|
||||
}
|
||||
// Add API key if provided
|
||||
if let Some(api_key) = self.api_key {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
let base_worker = builder.build();
|
||||
|
||||
@@ -382,6 +405,7 @@ mod tests {
|
||||
.connection_mode(ConnectionMode::Http)
|
||||
.labels(labels.clone())
|
||||
.health_config(health_config.clone())
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
assert_eq!(worker.url(), "http://localhost:8080@3");
|
||||
|
||||
@@ -256,6 +256,18 @@ impl WorkerRegistry {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_all_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
|
||||
self.workers
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
(
|
||||
entry.value().url().to_string(),
|
||||
entry.value().api_key().clone(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all model IDs with workers
|
||||
pub fn get_models(&self) -> Vec<String> {
|
||||
self.model_workers
|
||||
@@ -442,6 +454,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
@@ -477,6 +490,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels1)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
@@ -487,6 +501,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels2)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
@@ -497,6 +512,7 @@ mod tests {
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels3)
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
|
||||
@@ -465,11 +465,13 @@ mod tests {
|
||||
Arc::new(
|
||||
BasicWorkerBuilder::new("http://w1:8000")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
),
|
||||
Arc::new(
|
||||
BasicWorkerBuilder::new("http://w2:8000")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
),
|
||||
];
|
||||
|
||||
@@ -129,16 +129,19 @@ mod tests {
|
||||
Arc::new(
|
||||
BasicWorkerBuilder::new("http://w1:8000")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
),
|
||||
Arc::new(
|
||||
BasicWorkerBuilder::new("http://w2:8000")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key2")
|
||||
.build(),
|
||||
),
|
||||
Arc::new(
|
||||
BasicWorkerBuilder::new("http://w3:8000")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
),
|
||||
];
|
||||
|
||||
@@ -11,6 +11,10 @@ pub struct WorkerConfigRequest {
|
||||
/// Worker URL (required)
|
||||
pub url: String,
|
||||
|
||||
/// Worker API key (optional)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub api_key: Option<String>,
|
||||
|
||||
/// Model ID (optional, will query from server if not provided)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model_id: Option<String>,
|
||||
|
||||
@@ -353,7 +353,11 @@ impl RouterTrait for GrpcPDRouter {
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for GrpcPDRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
_worker_url: &str,
|
||||
_api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Err("Not implemented".to_string())
|
||||
}
|
||||
|
||||
|
||||
@@ -282,7 +282,11 @@ impl RouterTrait for GrpcRouter {
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for GrpcRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
_worker_url: &str,
|
||||
_api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Err("Not implemented".to_string())
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +67,11 @@ impl OpenAIRouter {
|
||||
|
||||
#[async_trait]
|
||||
impl super::super::WorkerManagement for OpenAIRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
_worker_url: &str,
|
||||
_api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Err("Cannot add workers to OpenAI router".to_string())
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,8 @@ pub struct PDRouter {
|
||||
pub prefill_client: Client,
|
||||
pub retry_config: RetryConfig,
|
||||
pub circuit_breaker_config: CircuitBreakerConfig,
|
||||
pub api_key: Option<String>,
|
||||
|
||||
// Channel for sending prefill responses to background workers for draining
|
||||
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
|
||||
}
|
||||
@@ -113,21 +115,25 @@ impl PDRouter {
|
||||
(results, errors)
|
||||
}
|
||||
|
||||
fn _get_worker_url_and_key(&self, w: &Arc<dyn Worker>) -> (String, Option<String>) {
|
||||
(w.url().to_string(), w.api_key().clone())
|
||||
}
|
||||
|
||||
// Helper to get prefill worker URLs
|
||||
fn get_prefill_worker_urls(&self) -> Vec<String> {
|
||||
fn get_prefill_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
|
||||
self.worker_registry
|
||||
.get_prefill_workers()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.map(|w| self._get_worker_url_and_key(w))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Helper to get decode worker URLs
|
||||
fn get_decode_worker_urls(&self) -> Vec<String> {
|
||||
fn get_decode_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
|
||||
self.worker_registry
|
||||
.get_decode_workers()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.map(|w| self._get_worker_url_and_key(w))
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -208,6 +214,7 @@ impl PDRouter {
|
||||
pub async fn add_prefill_server(
|
||||
&self,
|
||||
url: String,
|
||||
api_key: Option<String>,
|
||||
bootstrap_port: Option<u16>,
|
||||
) -> Result<String, PDRouterError> {
|
||||
// Wait for the new server to be healthy
|
||||
@@ -220,10 +227,15 @@ impl PDRouter {
|
||||
|
||||
// Create Worker for the new prefill server with circuit breaker configuration
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Prefill { bootstrap_port })
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone());
|
||||
|
||||
let worker = if let Some(api_key) = api_key {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc: Arc<dyn Worker> = Arc::new(worker);
|
||||
|
||||
@@ -243,7 +255,11 @@ impl PDRouter {
|
||||
Ok(format!("Successfully added prefill server: {}", url))
|
||||
}
|
||||
|
||||
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
|
||||
pub async fn add_decode_server(
|
||||
&self,
|
||||
url: String,
|
||||
api_key: Option<String>,
|
||||
) -> Result<String, PDRouterError> {
|
||||
// Wait for the new server to be healthy
|
||||
self.wait_for_server_health(&url).await?;
|
||||
|
||||
@@ -254,10 +270,15 @@ impl PDRouter {
|
||||
|
||||
// Create Worker for the new decode server with circuit breaker configuration
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone());
|
||||
|
||||
let worker = if let Some(api_key) = api_key {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc: Arc<dyn Worker> = Arc::new(worker);
|
||||
|
||||
@@ -366,6 +387,12 @@ impl PDRouter {
|
||||
.chain(decode_workers.iter())
|
||||
.map(|w| w.url().to_string())
|
||||
.collect();
|
||||
// Get all worker API keys for monitoring
|
||||
let all_api_keys: Vec<Option<String>> = prefill_workers
|
||||
.iter()
|
||||
.chain(decode_workers.iter())
|
||||
.map(|w| w.api_key().clone())
|
||||
.collect();
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||
@@ -387,6 +414,7 @@ impl PDRouter {
|
||||
let load_monitor_handle =
|
||||
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
|
||||
let monitor_urls = all_urls.clone();
|
||||
let monitor_api_keys = all_api_keys.clone();
|
||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
||||
let monitor_client = ctx.client.clone();
|
||||
let prefill_policy_clone = Arc::clone(&prefill_policy);
|
||||
@@ -395,6 +423,7 @@ impl PDRouter {
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
Self::monitor_worker_loads_with_client(
|
||||
monitor_urls,
|
||||
monitor_api_keys,
|
||||
tx,
|
||||
monitor_interval,
|
||||
monitor_client,
|
||||
@@ -500,6 +529,7 @@ impl PDRouter {
|
||||
prefill_drain_tx,
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
circuit_breaker_config: core_cb_config,
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1150,6 +1180,7 @@ impl PDRouter {
|
||||
// Background task to monitor worker loads with shared client
|
||||
async fn monitor_worker_loads_with_client(
|
||||
worker_urls: Vec<String>,
|
||||
worker_api_keys: Vec<Option<String>>,
|
||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||
interval_secs: u64,
|
||||
client: Client,
|
||||
@@ -1161,11 +1192,13 @@ impl PDRouter {
|
||||
|
||||
let futures: Vec<_> = worker_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
.zip(worker_api_keys.iter())
|
||||
.map(|(url, api_key)| {
|
||||
let client = client.clone();
|
||||
let url = url.clone();
|
||||
let api_key = api_key.clone();
|
||||
async move {
|
||||
let load = get_worker_load(&client, &url).await.unwrap_or(0);
|
||||
let load = get_worker_load(&client, &url, &api_key).await.unwrap_or(0);
|
||||
(url, load)
|
||||
}
|
||||
})
|
||||
@@ -1515,8 +1548,16 @@ impl PDRouter {
|
||||
|
||||
// Helper functions
|
||||
|
||||
async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
|
||||
match client.get(format!("{}/get_load", worker_url)).send().await {
|
||||
async fn get_worker_load(
|
||||
client: &Client,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Option<isize> {
|
||||
let mut req_builder = client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
@@ -1550,7 +1591,11 @@ async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for PDRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
_worker_url: &str,
|
||||
_api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
// For PD router, we don't support adding workers via this generic method
|
||||
Err(
|
||||
"PD router requires specific add_prefill_server or add_decode_server methods"
|
||||
@@ -1956,9 +2001,9 @@ impl RouterTrait for PDRouter {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Process prefill workers
|
||||
let prefill_urls = self.get_prefill_worker_urls();
|
||||
for worker_url in prefill_urls {
|
||||
match get_worker_load(&self.client, &worker_url).await {
|
||||
let prefill_urls_with_key = self.get_prefill_worker_urls_with_api_key();
|
||||
for (worker_url, api_key) in prefill_urls_with_key {
|
||||
match get_worker_load(&self.client, &worker_url, &api_key).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("prefill_{}", worker_url), load);
|
||||
}
|
||||
@@ -1969,9 +2014,9 @@ impl RouterTrait for PDRouter {
|
||||
}
|
||||
|
||||
// Process decode workers
|
||||
let decode_urls = self.get_decode_worker_urls();
|
||||
for worker_url in decode_urls {
|
||||
match get_worker_load(&self.client, &worker_url).await {
|
||||
let decode_urls_with_key = self.get_decode_worker_urls_with_api_key();
|
||||
for (worker_url, api_key) in decode_urls_with_key {
|
||||
match get_worker_load(&self.client, &worker_url, &api_key).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("decode_{}", worker_url), load);
|
||||
}
|
||||
@@ -2069,12 +2114,14 @@ mod tests {
|
||||
prefill_drain_tx: mpsc::channel(100).0,
|
||||
retry_config: RetryConfig::default(),
|
||||
circuit_breaker_config: CircuitBreakerConfig::default(),
|
||||
api_key: Some("test_api_key".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
|
||||
let worker = BasicWorkerBuilder::new(url)
|
||||
.worker_type(worker_type)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
worker.set_healthy(healthy);
|
||||
Box::new(worker)
|
||||
|
||||
@@ -38,6 +38,7 @@ pub struct Router {
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
#[allow(dead_code)]
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
@@ -71,7 +72,6 @@ impl Router {
|
||||
};
|
||||
|
||||
// Cache-aware policies are initialized in WorkerInitializer
|
||||
|
||||
// Setup load monitoring for PowerOfTwo policy
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
@@ -82,6 +82,14 @@ impl Router {
|
||||
// Check if default policy is power_of_two for load monitoring
|
||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
let monitor_api_keys = monitor_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
ctx.worker_registry
|
||||
.get_by_url(url)
|
||||
.and_then(|w| w.api_key().clone())
|
||||
})
|
||||
.collect::<Vec<Option<String>>>();
|
||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
||||
let policy_clone = default_policy.clone();
|
||||
let client_clone = ctx.client.clone();
|
||||
@@ -89,6 +97,7 @@ impl Router {
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
Self::monitor_worker_loads(
|
||||
monitor_urls,
|
||||
monitor_api_keys,
|
||||
tx,
|
||||
monitor_interval,
|
||||
policy_clone,
|
||||
@@ -912,7 +921,11 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
pub async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.worker_startup_timeout_secs))
|
||||
@@ -938,7 +951,7 @@ impl Router {
|
||||
// Need to contact the worker to extract the dp_size,
|
||||
// and add them as multiple workers
|
||||
let url_vec = vec![String::from(worker_url)];
|
||||
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
|
||||
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, api_key)
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
|
||||
let mut worker_added: bool = false;
|
||||
for dp_url in &dp_url_vec {
|
||||
@@ -948,10 +961,18 @@ impl Router {
|
||||
}
|
||||
info!("Added worker: {}", dp_url);
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker = BasicWorkerBuilder::new(dp_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
let new_worker_builder =
|
||||
BasicWorkerBuilder::new(dp_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
|
||||
let new_worker = if let Some(api_key) = api_key {
|
||||
new_worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
new_worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
@@ -978,10 +999,16 @@ impl Router {
|
||||
info!("Added worker: {}", worker_url);
|
||||
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker = BasicWorkerBuilder::new(worker_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
let new_worker_builder =
|
||||
BasicWorkerBuilder::new(worker_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone());
|
||||
|
||||
let new_worker = if let Some(api_key) = api_key {
|
||||
new_worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
new_worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
@@ -1094,7 +1121,7 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
|
||||
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
@@ -1109,12 +1136,12 @@ impl Router {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match self
|
||||
.client
|
||||
.get(format!("{}/get_load", worker_url))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
@@ -1149,6 +1176,7 @@ impl Router {
|
||||
// Background task to monitor worker loads
|
||||
async fn monitor_worker_loads(
|
||||
worker_urls: Vec<String>,
|
||||
worker_api_keys: Vec<Option<String>>,
|
||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||
interval_secs: u64,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
@@ -1160,8 +1188,8 @@ impl Router {
|
||||
interval.tick().await;
|
||||
|
||||
let mut loads = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
if let Some(load) = Self::get_worker_load_static(&client, url).await {
|
||||
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
|
||||
if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
|
||||
loads.insert(url.clone(), load);
|
||||
}
|
||||
}
|
||||
@@ -1179,7 +1207,11 @@ impl Router {
|
||||
}
|
||||
|
||||
// Static version of get_worker_load for use in monitoring task
|
||||
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
async fn get_worker_load_static(
|
||||
client: &reqwest::Client,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Option<isize> {
|
||||
let worker_url = if worker_url.contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
@@ -1194,7 +1226,11 @@ impl Router {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match client.get(format!("{}/get_load", worker_url)).send().await {
|
||||
let mut req_builder = client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
@@ -1250,8 +1286,12 @@ use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for Router {
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
Router::add_worker(self, worker_url).await
|
||||
async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Router::add_worker(self, worker_url, api_key).await
|
||||
}
|
||||
|
||||
fn remove_worker(&self, worker_url: &str) {
|
||||
@@ -1457,12 +1497,12 @@ impl RouterTrait for Router {
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
let urls = self.get_worker_urls();
|
||||
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
|
||||
let mut loads = Vec::new();
|
||||
|
||||
// Get loads from all workers
|
||||
for url in &urls {
|
||||
let load = self.get_worker_load(url).await.unwrap_or(-1);
|
||||
for (url, api_key) in &urls_with_key {
|
||||
let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
|
||||
loads.push(serde_json::json!({
|
||||
"worker": url,
|
||||
"load": load
|
||||
@@ -1521,9 +1561,11 @@ mod tests {
|
||||
// Register test workers
|
||||
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
worker_registry.register(Arc::new(worker1));
|
||||
worker_registry.register(Arc::new(worker2));
|
||||
|
||||
@@ -33,7 +33,11 @@ pub use http::{openai_router, pd_router, pd_types, router};
|
||||
#[async_trait]
|
||||
pub trait WorkerManagement: Send + Sync {
|
||||
/// Add a worker to the router
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String>;
|
||||
async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String>;
|
||||
|
||||
/// Remove a worker from the router
|
||||
fn remove_worker(&self, worker_url: &str);
|
||||
|
||||
@@ -161,7 +161,7 @@ impl RouterManager {
|
||||
let model_id = if let Some(model_id) = config.model_id {
|
||||
model_id
|
||||
} else {
|
||||
match self.query_server_info(&config.url).await {
|
||||
match self.query_server_info(&config.url, &config.api_key).await {
|
||||
Ok(info) => {
|
||||
// Extract model_id from server info
|
||||
info.model_id
|
||||
@@ -208,29 +208,44 @@ impl RouterManager {
|
||||
}
|
||||
|
||||
let worker = match config.worker_type.as_deref() {
|
||||
Some("prefill") => Box::new(
|
||||
BasicWorkerBuilder::new(config.url.clone())
|
||||
Some("prefill") => {
|
||||
let mut builder = BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: config.bootstrap_port,
|
||||
})
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
) as Box<dyn Worker>,
|
||||
Some("decode") => Box::new(
|
||||
BasicWorkerBuilder::new(config.url.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default());
|
||||
|
||||
if let Some(api_key) = config.api_key.clone() {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Box::new(builder.build()) as Box<dyn Worker>
|
||||
}
|
||||
Some("decode") => {
|
||||
let mut builder = BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
) as Box<dyn Worker>,
|
||||
_ => Box::new(
|
||||
BasicWorkerBuilder::new(config.url.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default());
|
||||
|
||||
if let Some(api_key) = config.api_key.clone() {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Box::new(builder.build()) as Box<dyn Worker>
|
||||
}
|
||||
_ => {
|
||||
let mut builder = BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
) as Box<dyn Worker>,
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default());
|
||||
|
||||
if let Some(api_key) = config.api_key.clone() {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Box::new(builder.build()) as Box<dyn Worker>
|
||||
}
|
||||
};
|
||||
|
||||
// Register worker
|
||||
@@ -346,10 +361,18 @@ impl RouterManager {
|
||||
}
|
||||
|
||||
/// Query server info from a worker URL
|
||||
async fn query_server_info(&self, url: &str) -> Result<ServerInfo, String> {
|
||||
async fn query_server_info(
|
||||
&self,
|
||||
url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<ServerInfo, String> {
|
||||
let info_url = format!("{}/get_server_info", url.trim_end_matches('/'));
|
||||
|
||||
match self.client.get(&info_url).send().await {
|
||||
let mut req_builder = self.client.get(&info_url);
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
response
|
||||
@@ -477,10 +500,15 @@ impl RouterManager {
|
||||
#[async_trait]
|
||||
impl WorkerManagement for RouterManager {
|
||||
/// Add a worker - in multi-router mode, this adds to the registry
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
// Create a basic worker config request
|
||||
let config = WorkerConfigRequest {
|
||||
url: worker_url.to_string(),
|
||||
api_key: api_key.clone(),
|
||||
model_id: None,
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
|
||||
@@ -27,8 +27,12 @@ impl WorkerInitializer {
|
||||
|
||||
match &config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
// use router's api_key, repeat for each worker
|
||||
let worker_api_keys: Vec<Option<String>> =
|
||||
worker_urls.iter().map(|_| config.api_key.clone()).collect();
|
||||
Self::create_regular_workers(
|
||||
worker_urls,
|
||||
&worker_api_keys,
|
||||
&config.connection_mode,
|
||||
config,
|
||||
worker_registry,
|
||||
@@ -41,8 +45,16 @@ impl WorkerInitializer {
|
||||
decode_urls,
|
||||
..
|
||||
} => {
|
||||
// use router's api_key, repeat for each prefill/decode worker
|
||||
let prefill_api_keys: Vec<Option<String>> = prefill_urls
|
||||
.iter()
|
||||
.map(|_| config.api_key.clone())
|
||||
.collect();
|
||||
let decode_api_keys: Vec<Option<String>> =
|
||||
decode_urls.iter().map(|_| config.api_key.clone()).collect();
|
||||
Self::create_prefill_workers(
|
||||
prefill_urls,
|
||||
&prefill_api_keys,
|
||||
&config.connection_mode,
|
||||
config,
|
||||
worker_registry,
|
||||
@@ -51,6 +63,7 @@ impl WorkerInitializer {
|
||||
.await?;
|
||||
Self::create_decode_workers(
|
||||
decode_urls,
|
||||
&decode_api_keys,
|
||||
&config.connection_mode,
|
||||
config,
|
||||
worker_registry,
|
||||
@@ -79,6 +92,7 @@ impl WorkerInitializer {
|
||||
/// Create regular workers for standard routing mode
|
||||
async fn create_regular_workers(
|
||||
urls: &[String],
|
||||
api_keys: &[Option<String>],
|
||||
config_connection_mode: &ConfigConnectionMode,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
@@ -109,14 +123,18 @@ impl WorkerInitializer {
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for url in urls {
|
||||
for (url, api_key) in urls.iter().zip(api_keys.iter()) {
|
||||
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.connection_mode(connection_mode.clone())
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(health_config.clone())
|
||||
.build();
|
||||
.health_config(health_config.clone());
|
||||
let worker = if let Some(api_key) = api_key.clone() {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
@@ -148,6 +166,7 @@ impl WorkerInitializer {
|
||||
/// Create prefill workers for disaggregated routing mode
|
||||
async fn create_prefill_workers(
|
||||
prefill_entries: &[(String, Option<u16>)],
|
||||
api_keys: &[Option<String>],
|
||||
config_connection_mode: &ConfigConnectionMode,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
@@ -181,16 +200,20 @@ impl WorkerInitializer {
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for (url, bootstrap_port) in prefill_entries {
|
||||
for ((url, bootstrap_port), api_key) in prefill_entries.iter().zip(api_keys.iter()) {
|
||||
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: *bootstrap_port,
|
||||
})
|
||||
.connection_mode(connection_mode.clone())
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(health_config.clone())
|
||||
.build();
|
||||
.health_config(health_config.clone());
|
||||
let worker = if let Some(api_key) = api_key.clone() {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
@@ -227,6 +250,7 @@ impl WorkerInitializer {
|
||||
/// Create decode workers for disaggregated routing mode
|
||||
async fn create_decode_workers(
|
||||
urls: &[String],
|
||||
api_keys: &[Option<String>],
|
||||
config_connection_mode: &ConfigConnectionMode,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
@@ -257,14 +281,18 @@ impl WorkerInitializer {
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for url in urls {
|
||||
for (url, api_key) in urls.iter().zip(api_keys.iter()) {
|
||||
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.connection_mode(connection_mode.clone())
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(health_config.clone())
|
||||
.build();
|
||||
.health_config(health_config.clone());
|
||||
let worker = if let Some(api_key) = api_key.clone() {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
|
||||
@@ -282,15 +282,16 @@ async fn v1_responses_list_input_items(
|
||||
// ---------- Worker management endpoints (Legacy) ----------
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UrlQuery {
|
||||
struct AddWorkerQuery {
|
||||
url: String,
|
||||
api_key: Option<String>,
|
||||
}
|
||||
|
||||
async fn add_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(UrlQuery { url }): Query<UrlQuery>,
|
||||
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
|
||||
) -> Response {
|
||||
match state.router.add_worker(&url).await {
|
||||
match state.router.add_worker(&url, &api_key).await {
|
||||
Ok(message) => (StatusCode::OK, message).into_response(),
|
||||
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
|
||||
}
|
||||
@@ -303,7 +304,7 @@ async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
|
||||
|
||||
async fn remove_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Query(UrlQuery { url }): Query<UrlQuery>,
|
||||
Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
|
||||
) -> Response {
|
||||
state.router.remove_worker(&url);
|
||||
(
|
||||
@@ -337,7 +338,7 @@ async fn create_worker(
|
||||
}
|
||||
} else {
|
||||
// In single router mode, use the router's add_worker with basic config
|
||||
match state.router.add_worker(&config.url).await {
|
||||
match state.router.add_worker(&config.url, &config.api_key).await {
|
||||
Ok(message) => {
|
||||
let response = WorkerApiResponse {
|
||||
success: true,
|
||||
|
||||
@@ -389,16 +389,20 @@ async fn handle_pod_event(
|
||||
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
|
||||
match &pod_info.pod_type {
|
||||
Some(PodType::Prefill) => pd_router
|
||||
.add_prefill_server(worker_url.clone(), pod_info.bootstrap_port)
|
||||
.add_prefill_server(
|
||||
worker_url.clone(),
|
||||
pd_router.api_key.clone(),
|
||||
pod_info.bootstrap_port,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
Some(PodType::Decode) => pd_router
|
||||
.add_decode_server(worker_url.clone())
|
||||
.add_decode_server(worker_url.clone(), pd_router.api_key.clone())
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
Some(PodType::Regular) | None => {
|
||||
// Fall back to regular add_worker for regular pods
|
||||
router.add_worker(&worker_url).await
|
||||
router.add_worker(&worker_url, &pd_router.api_key).await
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -406,7 +410,8 @@ async fn handle_pod_event(
|
||||
}
|
||||
} else {
|
||||
// Regular mode or no pod type specified
|
||||
router.add_worker(&worker_url).await
|
||||
// In pod, no need api key
|
||||
router.add_worker(&worker_url, &None).await
|
||||
};
|
||||
|
||||
match result {
|
||||
|
||||
@@ -18,6 +18,7 @@ fn test_backward_compatibility_with_empty_model_id() {
|
||||
// Create workers with empty model_id (simulating existing routers)
|
||||
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
// No model_id label - should default to "unknown"
|
||||
|
||||
@@ -25,6 +26,7 @@ fn test_backward_compatibility_with_empty_model_id() {
|
||||
labels2.insert("model_id".to_string(), "unknown".to_string());
|
||||
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.labels(labels2)
|
||||
.build();
|
||||
|
||||
@@ -59,6 +61,7 @@ fn test_mixed_model_ids() {
|
||||
// Create workers with different model_id scenarios
|
||||
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
// No model_id label - defaults to "unknown" which goes to "default" tree
|
||||
|
||||
@@ -67,6 +70,7 @@ fn test_mixed_model_ids() {
|
||||
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels2)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
let mut labels3 = HashMap::new();
|
||||
@@ -123,10 +127,12 @@ fn test_remove_worker_by_url_backward_compat() {
|
||||
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels1)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
|
||||
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
// No model_id label - defaults to "unknown"
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ async fn test_policy_registry_with_router_manager() {
|
||||
let _worker1_config = WorkerConfigRequest {
|
||||
url: "http://worker1:8000".to_string(),
|
||||
model_id: Some("llama-3".to_string()),
|
||||
api_key: Some("test_api_key".to_string()),
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
@@ -66,6 +67,7 @@ async fn test_policy_registry_with_router_manager() {
|
||||
let _worker2_config = WorkerConfigRequest {
|
||||
url: "http://worker2:8000".to_string(),
|
||||
model_id: Some("llama-3".to_string()),
|
||||
api_key: Some("test_api_key".to_string()),
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
@@ -86,6 +88,7 @@ async fn test_policy_registry_with_router_manager() {
|
||||
let _worker3_config = WorkerConfigRequest {
|
||||
url: "http://worker3:8000".to_string(),
|
||||
model_id: Some("gpt-4".to_string()),
|
||||
api_key: Some("test_api_key".to_string()),
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
|
||||
@@ -54,6 +54,7 @@ mod test_pd_routing {
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9000),
|
||||
})
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(prefill_worker.url(), "http://prefill:8080");
|
||||
@@ -68,6 +69,7 @@ mod test_pd_routing {
|
||||
let decode_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://decode:8080")
|
||||
.worker_type(WorkerType::Decode)
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(decode_worker.url(), "http://decode:8080");
|
||||
@@ -80,6 +82,7 @@ mod test_pd_routing {
|
||||
let regular_worker: Box<dyn Worker> = Box::new(
|
||||
BasicWorkerBuilder::new("http://regular:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
assert_eq!(regular_worker.url(), "http://regular:8080");
|
||||
@@ -297,6 +300,7 @@ mod test_pd_routing {
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9000),
|
||||
})
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
@@ -700,6 +704,7 @@ mod test_pd_routing {
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9000),
|
||||
})
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
@@ -836,6 +841,7 @@ mod test_pd_routing {
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: Some(9000),
|
||||
})
|
||||
.api_key("test_api_key")
|
||||
.build(),
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user