Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
use crate::pd_router::PDRouter;
|
||||
use crate::pd_types::PDSelectionPolicy;
|
||||
use crate::tree::Tree;
|
||||
use ::metrics::{counter, gauge, histogram};
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
@@ -15,7 +15,7 @@ use std::time::Instant;
|
||||
use tokio;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||
req.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
@@ -40,6 +40,9 @@ pub enum Router {
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
},
|
||||
PrefillDecode {
|
||||
pd_router: Arc<PDRouter>,
|
||||
},
|
||||
CacheAware {
|
||||
/*
|
||||
Cache-Aware Load Balancing Router
|
||||
@@ -133,6 +136,13 @@ pub enum PolicyConfig {
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
},
|
||||
PrefillDecodeConfig {
|
||||
selection_policy: PDSelectionPolicy,
|
||||
prefill_urls: Vec<(String, Option<u16>)>, // (url, bootstrap_port)
|
||||
decode_urls: Vec<String>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl Router {
|
||||
@@ -155,10 +165,24 @@ impl Router {
|
||||
interval_secs,
|
||||
..
|
||||
} => (*timeout_secs, *interval_secs),
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
..
|
||||
} => (*timeout_secs, *interval_secs),
|
||||
};
|
||||
|
||||
// Wait until all workers are healthy
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||
// For PrefillDecode, we need to handle workers differently
|
||||
match &policy_config {
|
||||
PolicyConfig::PrefillDecodeConfig { .. } => {
|
||||
// PD mode doesn't use the worker_urls parameter
|
||||
// We'll validate PD workers separately
|
||||
}
|
||||
_ => {
|
||||
// Wait until all workers are healthy for regular modes
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Create router based on policy...
|
||||
Ok(match policy_config {
|
||||
@@ -226,7 +250,7 @@ impl Router {
|
||||
});
|
||||
|
||||
for url in &worker_urls {
|
||||
tree.lock().unwrap().insert(&"".to_string(), url);
|
||||
tree.lock().unwrap().insert("", url);
|
||||
}
|
||||
|
||||
Router::CacheAware {
|
||||
@@ -242,6 +266,26 @@ impl Router {
|
||||
_eviction_thread: Some(eviction_thread),
|
||||
}
|
||||
}
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy,
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
} => {
|
||||
// Create PDRouter instance
|
||||
let pd_router = PDRouter::new(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
selection_policy,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
)?;
|
||||
|
||||
Router::PrefillDecode {
|
||||
pd_router: Arc::new(pd_router),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -251,16 +295,23 @@ impl Router {
|
||||
Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, return empty list since we manage workers differently
|
||||
Arc::new(RwLock::new(Vec::new()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wait_for_healthy_workers(
|
||||
pub fn wait_for_healthy_workers(
|
||||
worker_urls: &[String],
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<(), String> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let sync_client = reqwest::blocking::Client::new();
|
||||
let sync_client = reqwest::blocking::Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||
@@ -323,10 +374,14 @@ impl Router {
|
||||
Ok(worker_urls.read().unwrap()[0].clone())
|
||||
}
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, we don't need this method as routing is handled by PDRouter
|
||||
Err("PrefillDecode mode doesn't use select_first_worker".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
pub async fn send_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
worker_url: &str,
|
||||
@@ -339,7 +394,11 @@ impl Router {
|
||||
// Copy all headers from original request except for /health because it does not need authorization
|
||||
if route != "/health" {
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
request_builder = request_builder.header(name, value);
|
||||
// Skip Content-Type and Content-Length as .json() sets them
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -433,50 +492,193 @@ impl Router {
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
|
||||
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
|
||||
// Convert body to JSON
|
||||
let json: Value = match serde_json::from_slice(body) {
|
||||
Ok(j) => j,
|
||||
Err(_) => {
|
||||
warn!("Failed to parse JSON from request body.");
|
||||
return String::new();
|
||||
pub async fn route_to_all(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
route: &str,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
// Get all worker URLs based on router type
|
||||
let worker_urls = match self {
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, route_to_all is not supported directly
|
||||
// It should be handled by PDRouter if needed
|
||||
return HttpResponse::NotImplemented()
|
||||
.body("route_to_all not implemented for PrefillDecode mode");
|
||||
}
|
||||
_ => self.get_worker_urls().read().unwrap().clone(),
|
||||
};
|
||||
|
||||
match route {
|
||||
"/generate" => {
|
||||
// For /generate, always use the "text" field.
|
||||
match json.get("text").and_then(Value::as_str) {
|
||||
Some(text) => text.to_string(),
|
||||
None => {
|
||||
warn!("No 'text' field found in request body for route /generate.");
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
// Send requests to all workers concurrently
|
||||
let mut tasks = Vec::new();
|
||||
for worker_url in &worker_urls {
|
||||
let mut request_builder = client.post(format!("{}{}", worker_url, route));
|
||||
|
||||
// Copy headers from original request
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
"/v1/chat/completions" | "/v1/completions" => {
|
||||
// For these routes, try "messages", then "prompt", then "text".
|
||||
if let Some(messages) = json.get("messages") {
|
||||
serde_json::to_string(messages).unwrap_or_default()
|
||||
} else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
warn!("Failed to find 'messages', 'prompt' in request body.");
|
||||
String::new()
|
||||
}
|
||||
|
||||
tasks.push(request_builder.send());
|
||||
}
|
||||
|
||||
// Wait for all responses
|
||||
let results = futures_util::future::join_all(tasks).await;
|
||||
|
||||
// Check if all succeeded
|
||||
let all_success = results.iter().all(|r| {
|
||||
r.as_ref()
|
||||
.map(|res| res.status().is_success())
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
if all_success {
|
||||
HttpResponse::Ok().body("Operation completed on all servers")
|
||||
} else {
|
||||
HttpResponse::InternalServerError().body("Operation failed on one or more servers")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_all_loads(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
_req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
// For PD mode, delegate to PDRouter
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
return pd_router.get_loads(client).await;
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown route: {} - defaulting to fallback string", route);
|
||||
String::new()
|
||||
// For non-PD routers, handle normally
|
||||
}
|
||||
}
|
||||
|
||||
let urls = self.get_worker_urls().read().unwrap().clone();
|
||||
let prefill_urls: Vec<String> = Vec::new();
|
||||
let decode_urls = urls;
|
||||
|
||||
// Collect loads from all servers
|
||||
let mut prefill_loads = Vec::new();
|
||||
let mut decode_loads = Vec::new();
|
||||
|
||||
// Get prefill loads
|
||||
for url in &prefill_urls {
|
||||
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
||||
prefill_loads.push(serde_json::json!({
|
||||
"engine": format!("(Prefill@{})", url),
|
||||
"load": load as i64
|
||||
}));
|
||||
}
|
||||
|
||||
// Get decode loads
|
||||
for url in &decode_urls {
|
||||
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
||||
decode_loads.push(serde_json::json!({
|
||||
"engine": format!("(Decode@{})", url),
|
||||
"load": load as i64
|
||||
}));
|
||||
}
|
||||
|
||||
HttpResponse::Ok().json(serde_json::json!({
|
||||
"prefill": prefill_loads,
|
||||
"decode": decode_loads
|
||||
}))
|
||||
}
|
||||
|
||||
// New method to route typed requests directly
|
||||
pub async fn route_typed_request<
|
||||
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
||||
>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { .. } => HttpResponse::InternalServerError()
|
||||
.body("PD routing should use specialized typed handlers"),
|
||||
_ => {
|
||||
// Handle retries like the original implementation
|
||||
let start = Instant::now();
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||
let mut total_retries = 0;
|
||||
|
||||
while total_retries < MAX_TOTAL_RETRIES {
|
||||
// Extract routing text directly from typed request
|
||||
let text = typed_req.extract_text_for_routing();
|
||||
let is_stream = typed_req.is_stream();
|
||||
|
||||
// Select worker based on text
|
||||
let worker_url = self.select_generate_worker_from_text(&text);
|
||||
let mut request_retries = 0;
|
||||
|
||||
// Try the same worker multiple times
|
||||
while request_retries < MAX_REQUEST_RETRIES {
|
||||
if total_retries >= 1 {
|
||||
info!("Retrying request after {} failed attempts", total_retries);
|
||||
counter!("sgl_router_retries_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
// Send typed request directly
|
||||
let response = self
|
||||
.send_typed_request(
|
||||
client,
|
||||
req,
|
||||
typed_req,
|
||||
route,
|
||||
&worker_url,
|
||||
is_stream,
|
||||
)
|
||||
.await;
|
||||
|
||||
if response.status().is_success() {
|
||||
let duration = start.elapsed();
|
||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
|
||||
.record(duration.as_secs_f64());
|
||||
return response;
|
||||
} else {
|
||||
// if the worker is healthy, it means the request is bad, so return the error response
|
||||
let health_response =
|
||||
self.send_request(client, &worker_url, "/health", req).await;
|
||||
if health_response.status().is_success() {
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Generate request to {} failed (attempt {}/{})",
|
||||
worker_url,
|
||||
request_retries + 1,
|
||||
MAX_REQUEST_RETRIES
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
total_retries += 1;
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!("Removing failed worker: {}", worker_url);
|
||||
self.remove_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: return Result<String, String> instead of panicking
|
||||
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
|
||||
let text = self.get_text_from_request(&body, route);
|
||||
|
||||
let worker_url = match self {
|
||||
// Helper method to select worker from text
|
||||
fn select_generate_worker_from_text(&self, text: &str) -> String {
|
||||
match self {
|
||||
Router::RoundRobin {
|
||||
worker_urls,
|
||||
current_index,
|
||||
@@ -506,8 +708,6 @@ impl Router {
|
||||
balance_rel_threshold,
|
||||
..
|
||||
} => {
|
||||
// TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones
|
||||
|
||||
let tree = tree.lock().unwrap();
|
||||
let mut running_queue = running_queue.lock().unwrap();
|
||||
|
||||
@@ -572,35 +772,48 @@ impl Router {
|
||||
|
||||
selected_url
|
||||
}
|
||||
};
|
||||
|
||||
worker_url
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, we don't use this method
|
||||
return "PD_MODE_ERROR".to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_generate_request(
|
||||
// Send typed request directly without conversion
|
||||
async fn send_typed_request<T: serde::Serialize>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
body: &Bytes,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
worker_url: &str,
|
||||
is_stream: bool,
|
||||
) -> HttpResponse {
|
||||
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
||||
.unwrap_or(false);
|
||||
let start = Instant::now();
|
||||
|
||||
// Debug: Log what we're sending
|
||||
if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
|
||||
debug!("Sending request to {}: {}", route, json_str);
|
||||
}
|
||||
|
||||
let mut request_builder = client
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.body(body.to_vec());
|
||||
.json(typed_req); // Use json() directly with typed request
|
||||
|
||||
// Copy all headers from original request
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
request_builder = request_builder.header(name, value);
|
||||
// Skip Content-Type and Content-Length as .json() sets them
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
|
||||
request_builder = request_builder.header(&name, &value);
|
||||
}
|
||||
}
|
||||
|
||||
let res = match request_builder.send().await {
|
||||
Ok(res) => res,
|
||||
Err(_) => return HttpResponse::InternalServerError().finish(),
|
||||
Err(e) => {
|
||||
error!("Failed to send request to {}: {}", worker_url, e);
|
||||
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
@@ -625,6 +838,12 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// Record metrics
|
||||
let duration = start.elapsed();
|
||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
|
||||
.record(duration.as_secs_f64());
|
||||
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
|
||||
|
||||
response
|
||||
} else if let Router::CacheAware { running_queue, .. } = self {
|
||||
let running_queue = Arc::clone(running_queue);
|
||||
@@ -660,70 +879,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_generate_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
body: &Bytes,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
let start = Instant::now();
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||
let mut total_retries = 0;
|
||||
|
||||
while total_retries < MAX_TOTAL_RETRIES {
|
||||
let worker_url = self.select_generate_worker(body, route);
|
||||
let mut request_retries = 0;
|
||||
|
||||
// Try the same worker multiple times
|
||||
while request_retries < MAX_REQUEST_RETRIES {
|
||||
if total_retries >= 1 {
|
||||
info!("Retrying request after {} failed attempts", total_retries);
|
||||
counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1);
|
||||
}
|
||||
|
||||
let response = self
|
||||
.send_generate_request(client, req, body, route, &worker_url)
|
||||
.await;
|
||||
|
||||
if response.status().is_success() {
|
||||
let duration = start.elapsed();
|
||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()).record(duration.as_secs_f64());
|
||||
return response;
|
||||
} else {
|
||||
// if the worker is healthy, it means the request is bad, so return the error response
|
||||
let health_response =
|
||||
self.send_request(client, &worker_url, "/health", req).await;
|
||||
if health_response.status().is_success() {
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Generate request to {} failed (attempt {}/{})",
|
||||
worker_url,
|
||||
request_retries + 1,
|
||||
MAX_REQUEST_RETRIES
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
total_retries += 1;
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!("Removing failed worker: {}", worker_url);
|
||||
self.remove_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1);
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
|
||||
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
let (timeout_secs, interval_secs) = match self {
|
||||
Router::Random {
|
||||
@@ -741,10 +896,17 @@ impl Router {
|
||||
interval_secs,
|
||||
..
|
||||
} => (*timeout_secs, *interval_secs),
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, we don't support adding workers via this method
|
||||
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||
}
|
||||
};
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let client = reqwest::Client::new();
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||
@@ -774,6 +936,9 @@ impl Router {
|
||||
urls.push(worker_url.to_string());
|
||||
gauge!("sgl_router_active_workers").set(urls.len() as f64);
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// If cache aware, initialize the queues for the new worker
|
||||
@@ -797,7 +962,7 @@ impl Router {
|
||||
.insert(worker_url.to_string(), 0);
|
||||
|
||||
// Add worker to tree
|
||||
tree.lock().unwrap().insert(&"".to_string(), &worker_url);
|
||||
tree.lock().unwrap().insert("", worker_url);
|
||||
}
|
||||
|
||||
return Ok(format!("Successfully added worker: {}", worker_url));
|
||||
@@ -850,6 +1015,10 @@ impl Router {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// if cache aware, remove the worker from the tree
|
||||
@@ -875,4 +1044,133 @@ impl Router {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
.get("load")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as isize),
|
||||
Err(e) => {
|
||||
debug!("Failed to parse load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
debug!("Failed to read load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Ok(res) => {
|
||||
debug!(
|
||||
"Worker {} returned non-success status: {}",
|
||||
worker_url,
|
||||
res.status()
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to get load from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PD-specific wrapper methods that delegate to PDRouter
|
||||
pub async fn route_pd_health_generate(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
_req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.health_generate(&pd_router.http_client).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_pd_generate_typed(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
typed_req: crate::pd_types::GenerateReqInput,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router
|
||||
.route_generate(&pd_router.http_client, req, typed_req, route)
|
||||
.await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_pd_chat_typed(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
typed_req: crate::pd_types::ChatReqInput,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router
|
||||
.route_chat(&pd_router.http_client, req, typed_req, route)
|
||||
.await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pd_server_info(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
_req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.get_server_info(&pd_router.http_client).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pd_models(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.get_models(&pd_router.http_client, req).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.flush_cache(&pd_router.http_client).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pd_model_info(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.get_model_info(&pd_router.http_client, req).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user