refactor(pd-router): extract common patterns to reduce code duplication (#9081)
This commit is contained in:
@@ -72,6 +72,138 @@ impl PDRouter {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generic helper for processing all workers with an endpoint
|
||||||
|
async fn process_workers(
|
||||||
|
&self,
|
||||||
|
workers: &RwLock<Vec<Box<dyn Worker>>>,
|
||||||
|
worker_type: &str,
|
||||||
|
endpoint: &str,
|
||||||
|
) -> (Vec<String>, Vec<String>) {
|
||||||
|
let mut results = Vec::new();
|
||||||
|
let mut errors = Vec::new();
|
||||||
|
|
||||||
|
// Get worker URLs first to avoid holding lock across await
|
||||||
|
let urls = match workers.read() {
|
||||||
|
Ok(workers) => workers
|
||||||
|
.iter()
|
||||||
|
.map(|w| w.url().to_string())
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
Err(_) => {
|
||||||
|
errors.push(format!("Failed to access {} workers", worker_type));
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Process each worker
|
||||||
|
for worker_url in urls {
|
||||||
|
let url = format!("{}/{}", worker_url, endpoint);
|
||||||
|
match self.client.post(&url).send().await {
|
||||||
|
Ok(res) if res.status().is_success() => {
|
||||||
|
results.push(format!("{} {}: OK", worker_type, worker_url));
|
||||||
|
}
|
||||||
|
Ok(res) => {
|
||||||
|
errors.push(format!(
|
||||||
|
"{} {} returned status: {}",
|
||||||
|
worker_type,
|
||||||
|
worker_url,
|
||||||
|
res.status()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
errors.push(format!("{} {} error: {}", worker_type, worker_url, e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(results, errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to get worker URLs from a worker collection
|
||||||
|
fn get_worker_urls(
|
||||||
|
workers: &RwLock<Vec<Box<dyn Worker>>>,
|
||||||
|
worker_type: &str,
|
||||||
|
) -> Result<Vec<String>, String> {
|
||||||
|
workers
|
||||||
|
.read()
|
||||||
|
.map(|workers| {
|
||||||
|
workers
|
||||||
|
.iter()
|
||||||
|
.map(|w| w.url().to_string())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
.map_err(|_| format!("Failed to access {} workers", worker_type))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generic helper for proxying requests to the first worker
|
||||||
|
async fn proxy_to_first_worker(
|
||||||
|
&self,
|
||||||
|
workers: &RwLock<Vec<Box<dyn Worker>>>,
|
||||||
|
endpoint: &str,
|
||||||
|
worker_type: &str,
|
||||||
|
headers: Option<Vec<(String, String)>>,
|
||||||
|
) -> Response {
|
||||||
|
// Get first worker URL to avoid holding lock across await
|
||||||
|
let first_worker_url = match workers.read() {
|
||||||
|
Ok(workers) => workers.first().map(|w| w.url().to_string()),
|
||||||
|
Err(_) => {
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to access {} workers", worker_type),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(worker_url) = first_worker_url {
|
||||||
|
let url = format!("{}/{}", worker_url, endpoint);
|
||||||
|
let mut request_builder = self.client.get(&url);
|
||||||
|
|
||||||
|
// Add headers if provided
|
||||||
|
if let Some(headers) = headers {
|
||||||
|
for (name, value) in headers {
|
||||||
|
request_builder = request_builder.header(name, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match request_builder.send().await {
|
||||||
|
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||||
|
Ok(body) => (StatusCode::OK, body).into_response(),
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to read response body: {}", e);
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to read response body: {}", e),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Ok(res) => {
|
||||||
|
let status = StatusCode::from_u16(res.status().as_u16())
|
||||||
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
(
|
||||||
|
status,
|
||||||
|
format!("{} server returned status: {}", worker_type, res.status()),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to proxy request to {} server: {}", worker_type, e);
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to proxy request: {}", e),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
format!("No {} servers available", worker_type),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn add_prefill_server(
|
pub async fn add_prefill_server(
|
||||||
&self,
|
&self,
|
||||||
url: String,
|
url: String,
|
||||||
@@ -1384,191 +1516,32 @@ impl RouterTrait for PDRouter {
|
|||||||
|
|
||||||
/// Answers a server-info request by delegating to `proxy_to_first_worker`
/// with the decode worker collection, the `get_server_info` endpoint, the
/// label `"decode"`, and no extra headers. The incoming request is not
/// inspected.
async fn get_server_info(&self, _req: Request<Body>) -> Response {
    // Server info is taken from the decode side so the reply matches the
    // format sglang's server info is expected to have.
    let decode_pool = &self.decode_workers;
    self.proxy_to_first_worker(decode_pool, "get_server_info", "decode", None)
        .await
}
|
||||||
|
|
||||||
async fn get_models(&self, req: Request<Body>) -> Response {
|
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||||
// Extract headers first to avoid Send issues
|
// Extract headers first to avoid Send issues
|
||||||
let headers = crate::routers::router::copy_request_headers(&req);
|
let headers = crate::routers::router::copy_request_headers(&req);
|
||||||
|
|
||||||
// Get first prefill worker URL to avoid holding lock across await
|
// Proxy to first prefill worker
|
||||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
self.proxy_to_first_worker(&self.prefill_workers, "v1/models", "prefill", Some(headers))
|
||||||
workers.first().map(|w| w.url().to_string())
|
.await
|
||||||
} else {
|
|
||||||
return (
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
"Failed to access prefill workers",
|
|
||||||
)
|
|
||||||
.into_response();
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(worker_url) = first_worker_url {
|
|
||||||
let url = format!("{}/v1/models", worker_url);
|
|
||||||
let mut request_builder = self.client.get(&url);
|
|
||||||
|
|
||||||
// Add headers
|
|
||||||
for (name, value) in headers {
|
|
||||||
request_builder = request_builder.header(name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
match request_builder.send().await {
|
|
||||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
|
||||||
Ok(body) => (StatusCode::OK, body).into_response(),
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to read response body: {}", e);
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Failed to read response body: {}", e),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Ok(res) => {
|
|
||||||
let status = StatusCode::from_u16(res.status().as_u16())
|
|
||||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
(
|
|
||||||
status,
|
|
||||||
format!("Prefill server returned status: {}", res.status()),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to get models: {}", e);
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Failed to get models: {}", e),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
(
|
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
|
||||||
"No prefill servers available",
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||||
// Extract headers first to avoid Send issues
|
// Extract headers first to avoid Send issues
|
||||||
let headers = crate::routers::router::copy_request_headers(&req);
|
let headers = crate::routers::router::copy_request_headers(&req);
|
||||||
|
|
||||||
// Get first prefill worker URL to avoid holding lock across await
|
// Proxy to first prefill worker
|
||||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
self.proxy_to_first_worker(
|
||||||
workers.first().map(|w| w.url().to_string())
|
&self.prefill_workers,
|
||||||
} else {
|
"get_model_info",
|
||||||
return (
|
"prefill",
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
Some(headers),
|
||||||
"Failed to access prefill workers",
|
)
|
||||||
)
|
.await
|
||||||
.into_response();
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(worker_url) = first_worker_url {
|
|
||||||
let url = format!("{}/get_model_info", worker_url);
|
|
||||||
let mut request_builder = self.client.get(&url);
|
|
||||||
|
|
||||||
// Add headers
|
|
||||||
for (name, value) in headers {
|
|
||||||
request_builder = request_builder.header(name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
match request_builder.send().await {
|
|
||||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
|
||||||
Ok(body) => (StatusCode::OK, body).into_response(),
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to read response body: {}", e);
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Failed to read response body: {}", e),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Ok(res) => {
|
|
||||||
let status = StatusCode::from_u16(res.status().as_u16())
|
|
||||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
(
|
|
||||||
status,
|
|
||||||
format!("Prefill server returned status: {}", res.status()),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to get model info: {}", e);
|
|
||||||
(
|
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
|
||||||
format!("Failed to get model info: {}", e),
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
(
|
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
|
||||||
"No prefill servers available",
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_generate(
|
async fn route_generate(
|
||||||
@@ -1692,70 +1665,19 @@ impl RouterTrait for PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
async fn flush_cache(&self) -> Response {
|
||||||
let mut results = Vec::new();
|
// Process both prefill and decode workers
|
||||||
let mut errors = Vec::new();
|
let (prefill_results, prefill_errors) = self
|
||||||
|
.process_workers(&self.prefill_workers, "Prefill", "flush_cache")
|
||||||
|
.await;
|
||||||
|
let (decode_results, decode_errors) = self
|
||||||
|
.process_workers(&self.decode_workers, "Decode", "flush_cache")
|
||||||
|
.await;
|
||||||
|
|
||||||
// Get prefill worker URLs first to avoid holding lock across await
|
// Combine results and errors
|
||||||
let prefill_urls = if let Ok(workers) = self.prefill_workers.read() {
|
let mut results = prefill_results;
|
||||||
workers
|
results.extend(decode_results);
|
||||||
.iter()
|
let mut errors = prefill_errors;
|
||||||
.map(|w| w.url().to_string())
|
errors.extend(decode_errors);
|
||||||
.collect::<Vec<_>>()
|
|
||||||
} else {
|
|
||||||
errors.push("Failed to access prefill workers".to_string());
|
|
||||||
Vec::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Flush prefill workers
|
|
||||||
for worker_url in prefill_urls {
|
|
||||||
let url = format!("{}/flush_cache", worker_url);
|
|
||||||
match self.client.post(&url).send().await {
|
|
||||||
Ok(res) if res.status().is_success() => {
|
|
||||||
results.push(format!("Prefill {}: OK", worker_url));
|
|
||||||
}
|
|
||||||
Ok(res) => {
|
|
||||||
errors.push(format!(
|
|
||||||
"Prefill {} returned status: {}",
|
|
||||||
worker_url,
|
|
||||||
res.status()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
errors.push(format!("Prefill {} error: {}", worker_url, e));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get decode worker URLs first to avoid holding lock across await
|
|
||||||
let decode_urls = if let Ok(workers) = self.decode_workers.read() {
|
|
||||||
workers
|
|
||||||
.iter()
|
|
||||||
.map(|w| w.url().to_string())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
} else {
|
|
||||||
errors.push("Failed to access decode workers".to_string());
|
|
||||||
Vec::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Flush decode workers
|
|
||||||
for worker_url in decode_urls {
|
|
||||||
let url = format!("{}/flush_cache", worker_url);
|
|
||||||
match self.client.post(&url).send().await {
|
|
||||||
Ok(res) if res.status().is_success() => {
|
|
||||||
results.push(format!("Decode {}: OK", worker_url));
|
|
||||||
}
|
|
||||||
Ok(res) => {
|
|
||||||
errors.push(format!(
|
|
||||||
"Decode {} returned status: {}",
|
|
||||||
worker_url,
|
|
||||||
res.status()
|
|
||||||
));
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
errors.push(format!("Decode {} error: {}", worker_url, e));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if errors.is_empty() {
|
if errors.is_empty() {
|
||||||
(
|
(
|
||||||
@@ -1779,50 +1701,38 @@ impl RouterTrait for PDRouter {
|
|||||||
let mut loads = HashMap::new();
|
let mut loads = HashMap::new();
|
||||||
let mut errors = Vec::new();
|
let mut errors = Vec::new();
|
||||||
|
|
||||||
// Get prefill worker URLs first to avoid holding lock across await
|
// Process prefill workers
|
||||||
let prefill_urls = if let Ok(workers) = self.prefill_workers.read() {
|
match Self::get_worker_urls(&self.prefill_workers, "prefill") {
|
||||||
workers
|
Ok(urls) => {
|
||||||
.iter()
|
for worker_url in urls {
|
||||||
.map(|w| w.url().to_string())
|
match get_worker_load(&self.client, &worker_url).await {
|
||||||
.collect::<Vec<_>>()
|
Some(load) => {
|
||||||
} else {
|
loads.insert(format!("prefill_{}", worker_url), load);
|
||||||
errors.push("Failed to access prefill workers".to_string());
|
}
|
||||||
Vec::new()
|
None => {
|
||||||
};
|
errors.push(format!("Failed to get load from prefill {}", worker_url));
|
||||||
|
}
|
||||||
// Get loads from prefill workers
|
}
|
||||||
for worker_url in prefill_urls {
|
|
||||||
match get_worker_load(&self.client, &worker_url).await {
|
|
||||||
Some(load) => {
|
|
||||||
loads.insert(format!("prefill_{}", worker_url), load);
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
errors.push(format!("Failed to get load from prefill {}", worker_url));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Err(e) => errors.push(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get decode worker URLs first to avoid holding lock across await
|
// Process decode workers
|
||||||
let decode_urls = if let Ok(workers) = self.decode_workers.read() {
|
match Self::get_worker_urls(&self.decode_workers, "decode") {
|
||||||
workers
|
Ok(urls) => {
|
||||||
.iter()
|
for worker_url in urls {
|
||||||
.map(|w| w.url().to_string())
|
match get_worker_load(&self.client, &worker_url).await {
|
||||||
.collect::<Vec<_>>()
|
Some(load) => {
|
||||||
} else {
|
loads.insert(format!("decode_{}", worker_url), load);
|
||||||
errors.push("Failed to access decode workers".to_string());
|
}
|
||||||
Vec::new()
|
None => {
|
||||||
};
|
errors.push(format!("Failed to get load from decode {}", worker_url));
|
||||||
|
}
|
||||||
// Get loads from decode workers
|
}
|
||||||
for worker_url in decode_urls {
|
|
||||||
match get_worker_load(&self.client, &worker_url).await {
|
|
||||||
Some(load) => {
|
|
||||||
loads.insert(format!("decode_{}", worker_url), load);
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
errors.push(format!("Failed to get load from decode {}", worker_url));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Err(e) => errors.push(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
let response_data = serde_json::json!({
|
let response_data = serde_json::json!({
|
||||||
|
|||||||
Reference in New Issue
Block a user