refactor(pd-router): extract common patterns to reduce code duplication (#9081)
This commit is contained in:
@@ -72,6 +72,138 @@ impl PDRouter {
|
||||
})
|
||||
}
|
||||
|
||||
// Generic helper for processing all workers with an endpoint
|
||||
async fn process_workers(
|
||||
&self,
|
||||
workers: &RwLock<Vec<Box<dyn Worker>>>,
|
||||
worker_type: &str,
|
||||
endpoint: &str,
|
||||
) -> (Vec<String>, Vec<String>) {
|
||||
let mut results = Vec::new();
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Get worker URLs first to avoid holding lock across await
|
||||
let urls = match workers.read() {
|
||||
Ok(workers) => workers
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect::<Vec<_>>(),
|
||||
Err(_) => {
|
||||
errors.push(format!("Failed to access {} workers", worker_type));
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Process each worker
|
||||
for worker_url in urls {
|
||||
let url = format!("{}/{}", worker_url, endpoint);
|
||||
match self.client.post(&url).send().await {
|
||||
Ok(res) if res.status().is_success() => {
|
||||
results.push(format!("{} {}: OK", worker_type, worker_url));
|
||||
}
|
||||
Ok(res) => {
|
||||
errors.push(format!(
|
||||
"{} {} returned status: {}",
|
||||
worker_type,
|
||||
worker_url,
|
||||
res.status()
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
errors.push(format!("{} {} error: {}", worker_type, worker_url, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(results, errors)
|
||||
}
|
||||
|
||||
// Helper to get worker URLs from a worker collection
|
||||
fn get_worker_urls(
|
||||
workers: &RwLock<Vec<Box<dyn Worker>>>,
|
||||
worker_type: &str,
|
||||
) -> Result<Vec<String>, String> {
|
||||
workers
|
||||
.read()
|
||||
.map(|workers| {
|
||||
workers
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.map_err(|_| format!("Failed to access {} workers", worker_type))
|
||||
}
|
||||
|
||||
// Generic helper for proxying requests to the first worker
|
||||
async fn proxy_to_first_worker(
|
||||
&self,
|
||||
workers: &RwLock<Vec<Box<dyn Worker>>>,
|
||||
endpoint: &str,
|
||||
worker_type: &str,
|
||||
headers: Option<Vec<(String, String)>>,
|
||||
) -> Response {
|
||||
// Get first worker URL to avoid holding lock across await
|
||||
let first_worker_url = match workers.read() {
|
||||
Ok(workers) => workers.first().map(|w| w.url().to_string()),
|
||||
Err(_) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to access {} workers", worker_type),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(worker_url) = first_worker_url {
|
||||
let url = format!("{}/{}", worker_url, endpoint);
|
||||
let mut request_builder = self.client.get(&url);
|
||||
|
||||
// Add headers if provided
|
||||
if let Some(headers) = headers {
|
||||
for (name, value) in headers {
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
match request_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(body) => (StatusCode::OK, body).into_response(),
|
||||
Err(e) => {
|
||||
error!("Failed to read response body: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response body: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
},
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
(
|
||||
status,
|
||||
format!("{} server returned status: {}", worker_type, res.status()),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to proxy request to {} server: {}", worker_type, e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to proxy request: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
format!("No {} servers available", worker_type),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_prefill_server(
|
||||
&self,
|
||||
url: String,
|
||||
@@ -1384,191 +1516,32 @@ impl RouterTrait for PDRouter {
|
||||
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
// Get info from the first decode server to match sglang's server info format
|
||||
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to access decode workers",
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
if let Some(worker_url) = first_decode_url {
|
||||
match self
|
||||
.client
|
||||
.get(format!("{}/get_server_info", worker_url))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) if res.status().is_success() => {
|
||||
match res.json::<Value>().await {
|
||||
Ok(info) => {
|
||||
// The decode server should already return the proper format
|
||||
// with tokenizer_path and other fields that bench_one_batch_server.py expects
|
||||
Json(info).into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to parse server info: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to parse server info: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
(
|
||||
status,
|
||||
format!("Decode server returned status: {}", res.status()),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to get server info: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to get server info: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"No decode servers available",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
// Note: We use decode workers for server info to match expected format
|
||||
self.proxy_to_first_worker(&self.decode_workers, "get_server_info", "decode", None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||
// Extract headers first to avoid Send issues
|
||||
let headers = crate::routers::router::copy_request_headers(&req);
|
||||
|
||||
// Get first prefill worker URL to avoid holding lock across await
|
||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to access prefill workers",
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
if let Some(worker_url) = first_worker_url {
|
||||
let url = format!("{}/v1/models", worker_url);
|
||||
let mut request_builder = self.client.get(&url);
|
||||
|
||||
// Add headers
|
||||
for (name, value) in headers {
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
|
||||
match request_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(body) => (StatusCode::OK, body).into_response(),
|
||||
Err(e) => {
|
||||
error!("Failed to read response body: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response body: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
},
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
(
|
||||
status,
|
||||
format!("Prefill server returned status: {}", res.status()),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to get models: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to get models: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"No prefill servers available",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
// Proxy to first prefill worker
|
||||
self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||
// Extract headers first to avoid Send issues
|
||||
let headers = crate::routers::router::copy_request_headers(&req);
|
||||
|
||||
// Get first prefill worker URL to avoid holding lock across await
|
||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to access prefill workers",
|
||||
)
|
||||
.into_response();
|
||||
};
|
||||
|
||||
if let Some(worker_url) = first_worker_url {
|
||||
let url = format!("{}/get_model_info", worker_url);
|
||||
let mut request_builder = self.client.get(&url);
|
||||
|
||||
// Add headers
|
||||
for (name, value) in headers {
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
|
||||
match request_builder.send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(body) => (StatusCode::OK, body).into_response(),
|
||||
Err(e) => {
|
||||
error!("Failed to read response body: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response body: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
},
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
(
|
||||
status,
|
||||
format!("Prefill server returned status: {}", res.status()),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to get model info: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to get model info: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"No prefill servers available",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
// Proxy to first prefill worker
|
||||
self.proxy_to_first_worker(
|
||||
&self.prefill_workers,
|
||||
"get_model_info",
|
||||
"prefill",
|
||||
Some(headers),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
@@ -1692,70 +1665,19 @@ impl RouterTrait for PDRouter {
|
||||
}
|
||||
|
||||
async fn flush_cache(&self) -> Response {
|
||||
let mut results = Vec::new();
|
||||
let mut errors = Vec::new();
|
||||
// Process both prefill and decode workers
|
||||
let (prefill_results, prefill_errors) = self
|
||||
.process_workers(&self.prefill_workers, "Prefill", "flush_cache")
|
||||
.await;
|
||||
let (decode_results, decode_errors) = self
|
||||
.process_workers(&self.decode_workers, "Decode", "flush_cache")
|
||||
.await;
|
||||
|
||||
// Get prefill worker URLs first to avoid holding lock across await
|
||||
let prefill_urls = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
errors.push("Failed to access prefill workers".to_string());
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// Flush prefill workers
|
||||
for worker_url in prefill_urls {
|
||||
let url = format!("{}/flush_cache", worker_url);
|
||||
match self.client.post(&url).send().await {
|
||||
Ok(res) if res.status().is_success() => {
|
||||
results.push(format!("Prefill {}: OK", worker_url));
|
||||
}
|
||||
Ok(res) => {
|
||||
errors.push(format!(
|
||||
"Prefill {} returned status: {}",
|
||||
worker_url,
|
||||
res.status()
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
errors.push(format!("Prefill {} error: {}", worker_url, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get decode worker URLs first to avoid holding lock across await
|
||||
let decode_urls = if let Ok(workers) = self.decode_workers.read() {
|
||||
workers
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
errors.push("Failed to access decode workers".to_string());
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// Flush decode workers
|
||||
for worker_url in decode_urls {
|
||||
let url = format!("{}/flush_cache", worker_url);
|
||||
match self.client.post(&url).send().await {
|
||||
Ok(res) if res.status().is_success() => {
|
||||
results.push(format!("Decode {}: OK", worker_url));
|
||||
}
|
||||
Ok(res) => {
|
||||
errors.push(format!(
|
||||
"Decode {} returned status: {}",
|
||||
worker_url,
|
||||
res.status()
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
errors.push(format!("Decode {} error: {}", worker_url, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Combine results and errors
|
||||
let mut results = prefill_results;
|
||||
results.extend(decode_results);
|
||||
let mut errors = prefill_errors;
|
||||
errors.extend(decode_errors);
|
||||
|
||||
if errors.is_empty() {
|
||||
(
|
||||
@@ -1779,50 +1701,38 @@ impl RouterTrait for PDRouter {
|
||||
let mut loads = HashMap::new();
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// Get prefill worker URLs first to avoid holding lock across await
|
||||
let prefill_urls = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
errors.push("Failed to access prefill workers".to_string());
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// Get loads from prefill workers
|
||||
for worker_url in prefill_urls {
|
||||
match get_worker_load(&self.client, &worker_url).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("prefill_{}", worker_url), load);
|
||||
}
|
||||
None => {
|
||||
errors.push(format!("Failed to get load from prefill {}", worker_url));
|
||||
// Process prefill workers
|
||||
match Self::get_worker_urls(&self.prefill_workers, "prefill") {
|
||||
Ok(urls) => {
|
||||
for worker_url in urls {
|
||||
match get_worker_load(&self.client, &worker_url).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("prefill_{}", worker_url), load);
|
||||
}
|
||||
None => {
|
||||
errors.push(format!("Failed to get load from prefill {}", worker_url));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => errors.push(e),
|
||||
}
|
||||
|
||||
// Get decode worker URLs first to avoid holding lock across await
|
||||
let decode_urls = if let Ok(workers) = self.decode_workers.read() {
|
||||
workers
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
errors.push("Failed to access decode workers".to_string());
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// Get loads from decode workers
|
||||
for worker_url in decode_urls {
|
||||
match get_worker_load(&self.client, &worker_url).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("decode_{}", worker_url), load);
|
||||
}
|
||||
None => {
|
||||
errors.push(format!("Failed to get load from decode {}", worker_url));
|
||||
// Process decode workers
|
||||
match Self::get_worker_urls(&self.decode_workers, "decode") {
|
||||
Ok(urls) => {
|
||||
for worker_url in urls {
|
||||
match get_worker_load(&self.client, &worker_url).await {
|
||||
Some(load) => {
|
||||
loads.insert(format!("decode_{}", worker_url), load);
|
||||
}
|
||||
None => {
|
||||
errors.push(format!("Failed to get load from decode {}", worker_url));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => errors.push(e),
|
||||
}
|
||||
|
||||
let response_data = serde_json::json!({
|
||||
|
||||
Reference in New Issue
Block a user