[Router]fix: fix get_load missing api_key (#10385)
This commit is contained in:
@@ -353,7 +353,11 @@ impl RouterTrait for GrpcPDRouter {
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for GrpcPDRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
_worker_url: &str,
|
||||
_api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Err("Not implemented".to_string())
|
||||
}
|
||||
|
||||
|
||||
@@ -282,7 +282,11 @@ impl RouterTrait for GrpcRouter {
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for GrpcRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
_worker_url: &str,
|
||||
_api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Err("Not implemented".to_string())
|
||||
}
|
||||
|
||||
|
||||
@@ -67,7 +67,11 @@ impl OpenAIRouter {
|
||||
|
||||
#[async_trait]
|
||||
impl super::super::WorkerManagement for OpenAIRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
_worker_url: &str,
|
||||
_api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Err("Cannot add workers to OpenAI router".to_string())
|
||||
}
|
||||
|
||||
|
||||
@@ -46,6 +46,8 @@ pub struct PDRouter {
|
||||
pub prefill_client: Client,
|
||||
pub retry_config: RetryConfig,
|
||||
pub circuit_breaker_config: CircuitBreakerConfig,
|
||||
pub api_key: Option<String>,
|
||||
|
||||
// Channel for sending prefill responses to background workers for draining
|
||||
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
|
||||
}
|
||||
@@ -113,21 +115,25 @@ impl PDRouter {
|
||||
(results, errors)
|
||||
}
|
||||
|
||||
fn _get_worker_url_and_key(&self, w: &Arc<dyn Worker>) -> (String, Option<String>) {
|
||||
(w.url().to_string(), w.api_key().clone())
|
||||
}
|
||||
|
||||
// Helper to get prefill worker URLs
|
||||
fn get_prefill_worker_urls(&self) -> Vec<String> {
|
||||
fn get_prefill_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
|
||||
self.worker_registry
|
||||
.get_prefill_workers()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.map(|w| self._get_worker_url_and_key(w))
|
||||
.collect()
|
||||
}
|
||||
|
||||
// Helper to get decode worker URLs
|
||||
fn get_decode_worker_urls(&self) -> Vec<String> {
|
||||
fn get_decode_worker_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
|
||||
self.worker_registry
|
||||
.get_decode_workers()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.map(|w| self._get_worker_url_and_key(w))
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -208,6 +214,7 @@ impl PDRouter {
|
||||
pub async fn add_prefill_server(
|
||||
&self,
|
||||
url: String,
|
||||
api_key: Option<String>,
|
||||
bootstrap_port: Option<u16>,
|
||||
) -> Result<String, PDRouterError> {
|
||||
// Wait for the new server to be healthy
|
||||
@@ -220,10 +227,15 @@ impl PDRouter {
|
||||
|
||||
// Create Worker for the new prefill server with circuit breaker configuration
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Prefill { bootstrap_port })
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone());
|
||||
|
||||
let worker = if let Some(api_key) = api_key {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc: Arc<dyn Worker> = Arc::new(worker);
|
||||
|
||||
@@ -243,7 +255,11 @@ impl PDRouter {
|
||||
Ok(format!("Successfully added prefill server: {}", url))
|
||||
}
|
||||
|
||||
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
|
||||
pub async fn add_decode_server(
|
||||
&self,
|
||||
url: String,
|
||||
api_key: Option<String>,
|
||||
) -> Result<String, PDRouterError> {
|
||||
// Wait for the new server to be healthy
|
||||
self.wait_for_server_health(&url).await?;
|
||||
|
||||
@@ -254,10 +270,15 @@ impl PDRouter {
|
||||
|
||||
// Create Worker for the new decode server with circuit breaker configuration
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone());
|
||||
|
||||
let worker = if let Some(api_key) = api_key {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc: Arc<dyn Worker> = Arc::new(worker);
|
||||
|
||||
@@ -366,6 +387,12 @@ impl PDRouter {
|
||||
.chain(decode_workers.iter())
|
||||
.map(|w| w.url().to_string())
|
||||
.collect();
|
||||
// Get all worker API keys for monitoring
|
||||
let all_api_keys: Vec<Option<String>> = prefill_workers
|
||||
.iter()
|
||||
.chain(decode_workers.iter())
|
||||
.map(|w| w.api_key().clone())
|
||||
.collect();
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||
@@ -387,6 +414,7 @@ impl PDRouter {
|
||||
let load_monitor_handle =
|
||||
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
|
||||
let monitor_urls = all_urls.clone();
|
||||
let monitor_api_keys = all_api_keys.clone();
|
||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
||||
let monitor_client = ctx.client.clone();
|
||||
let prefill_policy_clone = Arc::clone(&prefill_policy);
|
||||
@@ -395,6 +423,7 @@ impl PDRouter {
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
Self::monitor_worker_loads_with_client(
|
||||
monitor_urls,
|
||||
monitor_api_keys,
|
||||
tx,
|
||||
monitor_interval,
|
||||
monitor_client,
|
||||
@@ -500,6 +529,7 @@ impl PDRouter {
|
||||
prefill_drain_tx,
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
circuit_breaker_config: core_cb_config,
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1150,6 +1180,7 @@ impl PDRouter {
|
||||
// Background task to monitor worker loads with shared client
|
||||
async fn monitor_worker_loads_with_client(
|
||||
worker_urls: Vec<String>,
|
||||
worker_api_keys: Vec<Option<String>>,
|
||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||
interval_secs: u64,
|
||||
client: Client,
|
||||
@@ -1161,11 +1192,13 @@ impl PDRouter {
|
||||
|
||||
let futures: Vec<_> = worker_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
.zip(worker_api_keys.iter())
|
||||
.map(|(url, api_key)| {
|
||||
let client = client.clone();
|
||||
let url = url.clone();
|
||||
let api_key = api_key.clone();
|
||||
async move {
|
||||
let load = get_worker_load(&client, &url).await.unwrap_or(0);
|
||||
let load = get_worker_load(&client, &url, &api_key).await.unwrap_or(0);
|
||||
(url, load)
|
||||
}
|
||||
})
|
||||
@@ -1515,8 +1548,16 @@ impl PDRouter {
|
||||
|
||||
// Helper functions
|
||||
|
||||
async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
|
||||
match client.get(format!("{}/get_load", worker_url)).send().await {
|
||||
async fn get_worker_load(
|
||||
client: &Client,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Option<isize> {
|
||||
let mut req_builder = client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
@@ -1550,7 +1591,11 @@ async fn get_worker_load(client: &Client, worker_url: &str) -> Option<isize> {
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for PDRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
_worker_url: &str,
|
||||
_api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
// For PD router, we don't support adding workers via this generic method
|
||||
Err(
|
||||
"PD router requires specific add_prefill_server or add_decode_server methods"
|
||||
@@ -1956,9 +2001,9 @@ impl RouterTrait for PDRouter {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Process prefill workers
|
||||
let prefill_urls = self.get_prefill_worker_urls();
|
||||
for worker_url in prefill_urls {
|
||||
match get_worker_load(&self.client, &worker_url).await {
|
||||
let prefill_urls_with_key = self.get_prefill_worker_urls_with_api_key();
|
||||
for (worker_url, api_key) in prefill_urls_with_key {
|
||||
match get_worker_load(&self.client, &worker_url, &api_key).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("prefill_{}", worker_url), load);
|
||||
}
|
||||
@@ -1969,9 +2014,9 @@ impl RouterTrait for PDRouter {
|
||||
}
|
||||
|
||||
// Process decode workers
|
||||
let decode_urls = self.get_decode_worker_urls();
|
||||
for worker_url in decode_urls {
|
||||
match get_worker_load(&self.client, &worker_url).await {
|
||||
let decode_urls_with_key = self.get_decode_worker_urls_with_api_key();
|
||||
for (worker_url, api_key) in decode_urls_with_key {
|
||||
match get_worker_load(&self.client, &worker_url, &api_key).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("decode_{}", worker_url), load);
|
||||
}
|
||||
@@ -2069,12 +2114,14 @@ mod tests {
|
||||
prefill_drain_tx: mpsc::channel(100).0,
|
||||
retry_config: RetryConfig::default(),
|
||||
circuit_breaker_config: CircuitBreakerConfig::default(),
|
||||
api_key: Some("test_api_key".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_worker(url: String, worker_type: WorkerType, healthy: bool) -> Box<dyn Worker> {
|
||||
let worker = BasicWorkerBuilder::new(url)
|
||||
.worker_type(worker_type)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
worker.set_healthy(healthy);
|
||||
Box::new(worker)
|
||||
|
||||
@@ -38,6 +38,7 @@ pub struct Router {
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
#[allow(dead_code)]
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
@@ -71,7 +72,6 @@ impl Router {
|
||||
};
|
||||
|
||||
// Cache-aware policies are initialized in WorkerInitializer
|
||||
|
||||
// Setup load monitoring for PowerOfTwo policy
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
@@ -82,6 +82,14 @@ impl Router {
|
||||
// Check if default policy is power_of_two for load monitoring
|
||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
let monitor_api_keys = monitor_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
ctx.worker_registry
|
||||
.get_by_url(url)
|
||||
.and_then(|w| w.api_key().clone())
|
||||
})
|
||||
.collect::<Vec<Option<String>>>();
|
||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
||||
let policy_clone = default_policy.clone();
|
||||
let client_clone = ctx.client.clone();
|
||||
@@ -89,6 +97,7 @@ impl Router {
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
Self::monitor_worker_loads(
|
||||
monitor_urls,
|
||||
monitor_api_keys,
|
||||
tx,
|
||||
monitor_interval,
|
||||
policy_clone,
|
||||
@@ -912,7 +921,11 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
pub async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(self.worker_startup_timeout_secs))
|
||||
@@ -938,7 +951,7 @@ impl Router {
|
||||
// Need to contact the worker to extract the dp_size,
|
||||
// and add them as multiple workers
|
||||
let url_vec = vec![String::from(worker_url)];
|
||||
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
|
||||
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, api_key)
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
|
||||
let mut worker_added: bool = false;
|
||||
for dp_url in &dp_url_vec {
|
||||
@@ -948,10 +961,18 @@ impl Router {
|
||||
}
|
||||
info!("Added worker: {}", dp_url);
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker = BasicWorkerBuilder::new(dp_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
let new_worker_builder =
|
||||
BasicWorkerBuilder::new(dp_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
|
||||
let new_worker = if let Some(api_key) = api_key {
|
||||
new_worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
new_worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
@@ -978,10 +999,16 @@ impl Router {
|
||||
info!("Added worker: {}", worker_url);
|
||||
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker = BasicWorkerBuilder::new(worker_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone())
|
||||
.build();
|
||||
let new_worker_builder =
|
||||
BasicWorkerBuilder::new(worker_url.to_string())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.circuit_breaker_config(self.circuit_breaker_config.clone());
|
||||
|
||||
let new_worker = if let Some(api_key) = api_key {
|
||||
new_worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
new_worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
@@ -1094,7 +1121,7 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
|
||||
async fn get_worker_load(&self, worker_url: &str, api_key: &Option<String>) -> Option<isize> {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
@@ -1109,12 +1136,12 @@ impl Router {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match self
|
||||
.client
|
||||
.get(format!("{}/get_load", worker_url))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
let mut req_builder = self.client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
@@ -1149,6 +1176,7 @@ impl Router {
|
||||
// Background task to monitor worker loads
|
||||
async fn monitor_worker_loads(
|
||||
worker_urls: Vec<String>,
|
||||
worker_api_keys: Vec<Option<String>>,
|
||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||
interval_secs: u64,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
@@ -1160,8 +1188,8 @@ impl Router {
|
||||
interval.tick().await;
|
||||
|
||||
let mut loads = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
if let Some(load) = Self::get_worker_load_static(&client, url).await {
|
||||
for (url, api_key) in worker_urls.iter().zip(worker_api_keys.iter()) {
|
||||
if let Some(load) = Self::get_worker_load_static(&client, url, api_key).await {
|
||||
loads.insert(url.clone(), load);
|
||||
}
|
||||
}
|
||||
@@ -1179,7 +1207,11 @@ impl Router {
|
||||
}
|
||||
|
||||
// Static version of get_worker_load for use in monitoring task
|
||||
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
async fn get_worker_load_static(
|
||||
client: &reqwest::Client,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Option<isize> {
|
||||
let worker_url = if worker_url.contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
@@ -1194,7 +1226,11 @@ impl Router {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match client.get(format!("{}/get_load", worker_url)).send().await {
|
||||
let mut req_builder = client.get(format!("{}/get_load", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
@@ -1250,8 +1286,12 @@ use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for Router {
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
Router::add_worker(self, worker_url).await
|
||||
async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
Router::add_worker(self, worker_url, api_key).await
|
||||
}
|
||||
|
||||
fn remove_worker(&self, worker_url: &str) {
|
||||
@@ -1457,12 +1497,12 @@ impl RouterTrait for Router {
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
let urls = self.get_worker_urls();
|
||||
let urls_with_key = self.worker_registry.get_all_urls_with_api_key();
|
||||
let mut loads = Vec::new();
|
||||
|
||||
// Get loads from all workers
|
||||
for url in &urls {
|
||||
let load = self.get_worker_load(url).await.unwrap_or(-1);
|
||||
for (url, api_key) in &urls_with_key {
|
||||
let load = self.get_worker_load(url, api_key).await.unwrap_or(-1);
|
||||
loads.push(serde_json::json!({
|
||||
"worker": url,
|
||||
"load": load
|
||||
@@ -1521,9 +1561,11 @@ mod tests {
|
||||
// Register test workers
|
||||
let worker1 = BasicWorkerBuilder::new("http://worker1:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
let worker2 = BasicWorkerBuilder::new("http://worker2:8080")
|
||||
.worker_type(WorkerType::Regular)
|
||||
.api_key("test_api_key")
|
||||
.build();
|
||||
worker_registry.register(Arc::new(worker1));
|
||||
worker_registry.register(Arc::new(worker2));
|
||||
|
||||
@@ -33,7 +33,11 @@ pub use http::{openai_router, pd_router, pd_types, router};
|
||||
#[async_trait]
|
||||
pub trait WorkerManagement: Send + Sync {
|
||||
/// Add a worker to the router
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String>;
|
||||
async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String>;
|
||||
|
||||
/// Remove a worker from the router
|
||||
fn remove_worker(&self, worker_url: &str);
|
||||
|
||||
@@ -161,7 +161,7 @@ impl RouterManager {
|
||||
let model_id = if let Some(model_id) = config.model_id {
|
||||
model_id
|
||||
} else {
|
||||
match self.query_server_info(&config.url).await {
|
||||
match self.query_server_info(&config.url, &config.api_key).await {
|
||||
Ok(info) => {
|
||||
// Extract model_id from server info
|
||||
info.model_id
|
||||
@@ -208,29 +208,44 @@ impl RouterManager {
|
||||
}
|
||||
|
||||
let worker = match config.worker_type.as_deref() {
|
||||
Some("prefill") => Box::new(
|
||||
BasicWorkerBuilder::new(config.url.clone())
|
||||
Some("prefill") => {
|
||||
let mut builder = BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: config.bootstrap_port,
|
||||
})
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
) as Box<dyn Worker>,
|
||||
Some("decode") => Box::new(
|
||||
BasicWorkerBuilder::new(config.url.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default());
|
||||
|
||||
if let Some(api_key) = config.api_key.clone() {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Box::new(builder.build()) as Box<dyn Worker>
|
||||
}
|
||||
Some("decode") => {
|
||||
let mut builder = BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
) as Box<dyn Worker>,
|
||||
_ => Box::new(
|
||||
BasicWorkerBuilder::new(config.url.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default());
|
||||
|
||||
if let Some(api_key) = config.api_key.clone() {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Box::new(builder.build()) as Box<dyn Worker>
|
||||
}
|
||||
_ => {
|
||||
let mut builder = BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default())
|
||||
.build(),
|
||||
) as Box<dyn Worker>,
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default());
|
||||
|
||||
if let Some(api_key) = config.api_key.clone() {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Box::new(builder.build()) as Box<dyn Worker>
|
||||
}
|
||||
};
|
||||
|
||||
// Register worker
|
||||
@@ -346,10 +361,18 @@ impl RouterManager {
|
||||
}
|
||||
|
||||
/// Query server info from a worker URL
|
||||
async fn query_server_info(&self, url: &str) -> Result<ServerInfo, String> {
|
||||
async fn query_server_info(
|
||||
&self,
|
||||
url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<ServerInfo, String> {
|
||||
let info_url = format!("{}/get_server_info", url.trim_end_matches('/'));
|
||||
|
||||
match self.client.get(&info_url).send().await {
|
||||
let mut req_builder = self.client.get(&info_url);
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
response
|
||||
@@ -477,10 +500,15 @@ impl RouterManager {
|
||||
#[async_trait]
|
||||
impl WorkerManagement for RouterManager {
|
||||
/// Add a worker - in multi-router mode, this adds to the registry
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
// Create a basic worker config request
|
||||
let config = WorkerConfigRequest {
|
||||
url: worker_url.to_string(),
|
||||
api_key: api_key.clone(),
|
||||
model_id: None,
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
|
||||
@@ -27,8 +27,12 @@ impl WorkerInitializer {
|
||||
|
||||
match &config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
// use router's api_key, repeat for each worker
|
||||
let worker_api_keys: Vec<Option<String>> =
|
||||
worker_urls.iter().map(|_| config.api_key.clone()).collect();
|
||||
Self::create_regular_workers(
|
||||
worker_urls,
|
||||
&worker_api_keys,
|
||||
&config.connection_mode,
|
||||
config,
|
||||
worker_registry,
|
||||
@@ -41,8 +45,16 @@ impl WorkerInitializer {
|
||||
decode_urls,
|
||||
..
|
||||
} => {
|
||||
// use router's api_key, repeat for each prefill/decode worker
|
||||
let prefill_api_keys: Vec<Option<String>> = prefill_urls
|
||||
.iter()
|
||||
.map(|_| config.api_key.clone())
|
||||
.collect();
|
||||
let decode_api_keys: Vec<Option<String>> =
|
||||
decode_urls.iter().map(|_| config.api_key.clone()).collect();
|
||||
Self::create_prefill_workers(
|
||||
prefill_urls,
|
||||
&prefill_api_keys,
|
||||
&config.connection_mode,
|
||||
config,
|
||||
worker_registry,
|
||||
@@ -51,6 +63,7 @@ impl WorkerInitializer {
|
||||
.await?;
|
||||
Self::create_decode_workers(
|
||||
decode_urls,
|
||||
&decode_api_keys,
|
||||
&config.connection_mode,
|
||||
config,
|
||||
worker_registry,
|
||||
@@ -79,6 +92,7 @@ impl WorkerInitializer {
|
||||
/// Create regular workers for standard routing mode
|
||||
async fn create_regular_workers(
|
||||
urls: &[String],
|
||||
api_keys: &[Option<String>],
|
||||
config_connection_mode: &ConfigConnectionMode,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
@@ -109,14 +123,18 @@ impl WorkerInitializer {
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for url in urls {
|
||||
for (url, api_key) in urls.iter().zip(api_keys.iter()) {
|
||||
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.connection_mode(connection_mode.clone())
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(health_config.clone())
|
||||
.build();
|
||||
.health_config(health_config.clone());
|
||||
let worker = if let Some(api_key) = api_key.clone() {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
@@ -148,6 +166,7 @@ impl WorkerInitializer {
|
||||
/// Create prefill workers for disaggregated routing mode
|
||||
async fn create_prefill_workers(
|
||||
prefill_entries: &[(String, Option<u16>)],
|
||||
api_keys: &[Option<String>],
|
||||
config_connection_mode: &ConfigConnectionMode,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
@@ -181,16 +200,20 @@ impl WorkerInitializer {
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for (url, bootstrap_port) in prefill_entries {
|
||||
for ((url, bootstrap_port), api_key) in prefill_entries.iter().zip(api_keys.iter()) {
|
||||
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: *bootstrap_port,
|
||||
})
|
||||
.connection_mode(connection_mode.clone())
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(health_config.clone())
|
||||
.build();
|
||||
.health_config(health_config.clone());
|
||||
let worker = if let Some(api_key) = api_key.clone() {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
@@ -227,6 +250,7 @@ impl WorkerInitializer {
|
||||
/// Create decode workers for disaggregated routing mode
|
||||
async fn create_decode_workers(
|
||||
urls: &[String],
|
||||
api_keys: &[Option<String>],
|
||||
config_connection_mode: &ConfigConnectionMode,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
@@ -257,14 +281,18 @@ impl WorkerInitializer {
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for url in urls {
|
||||
for (url, api_key) in urls.iter().zip(api_keys.iter()) {
|
||||
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
let worker_builder = BasicWorkerBuilder::new(url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.connection_mode(connection_mode.clone())
|
||||
.circuit_breaker_config(core_cb_config.clone())
|
||||
.health_config(health_config.clone())
|
||||
.build();
|
||||
.health_config(health_config.clone());
|
||||
let worker = if let Some(api_key) = api_key.clone() {
|
||||
worker_builder.api_key(api_key).build()
|
||||
} else {
|
||||
worker_builder.build()
|
||||
};
|
||||
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
|
||||
Reference in New Issue
Block a user