diff --git a/sgl-pdlb/.rustfmt.toml b/sgl-pdlb/.rustfmt.toml deleted file mode 100644 index 745fb75b4..000000000 --- a/sgl-pdlb/.rustfmt.toml +++ /dev/null @@ -1,2 +0,0 @@ -reorder_imports = true -reorder_modules = true diff --git a/sgl-pdlb/Cargo.toml b/sgl-pdlb/Cargo.toml deleted file mode 100644 index bcfe8e1de..000000000 --- a/sgl-pdlb/Cargo.toml +++ /dev/null @@ -1,28 +0,0 @@ -[package] -edition = "2024" -name = "sgl-pdlb" -version = "0.1.0" - -[lib] -crate-type = ["cdylib", "rlib"] -name = "sgl_pdlb_rs" - -[dependencies] -actix-web = "4.11" -bytes = "1.8.0" -chrono = "0.4.38" -clap = { version = "4.4", features = ["derive"] } -dashmap = "6.1.0" -env_logger = "0.11.5" -futures = "0.3" -futures-util = "0.3" -http = "1.3.1" -log = "0.4.22" -pyo3 = { version = "0.25.0", features = ["extension-module"] } -rand = "0.9.0" -reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] } -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -tokio = { version = "1.34", features = ["full"] } -anyhow = "1.0.98" -typetag = "0.2.20" diff --git a/sgl-pdlb/README.md b/sgl-pdlb/README.md deleted file mode 100644 index c763ed501..000000000 --- a/sgl-pdlb/README.md +++ /dev/null @@ -1,12 +0,0 @@ -### Install dependencies - -```bash -pip install "maturin[patchelf]" -``` - -### Build and install - -```bash -maturin develop -pip install -e . -``` diff --git a/sgl-pdlb/py_src/sgl_pdlb/__init__.py b/sgl-pdlb/py_src/sgl_pdlb/__init__.py deleted file mode 100644 index f102a9cad..000000000 --- a/sgl-pdlb/py_src/sgl_pdlb/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.0.1" diff --git a/sgl-pdlb/pyproject.toml b/sgl-pdlb/pyproject.toml deleted file mode 100644 index 4a0f80b75..000000000 --- a/sgl-pdlb/pyproject.toml +++ /dev/null @@ -1,14 +0,0 @@ -[build-system] -requires = ["maturin>=1.8.0"] -build-backend = "maturin" - -[project] -name = "sgl_pdlb" -version = "0.0.1" - -[tool.maturin] -python-source = "py_src" -module-name = "sgl_pdlb._rust" - -[tool.maturin.build-backend] -features = ["pyo3/extension-module"] diff --git a/sgl-pdlb/src/io_struct.rs b/sgl-pdlb/src/io_struct.rs deleted file mode 100644 index 804aca58b..000000000 --- a/sgl-pdlb/src/io_struct.rs +++ /dev/null @@ -1,133 +0,0 @@ -use crate::strategy_lb::EngineInfo; -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -#[derive(Debug, Deserialize, Serialize)] -#[serde(untagged)] -pub enum SingleOrBatch { - Single(T), - Batch(Vec), -} - -pub type InputIds = SingleOrBatch>; -pub type InputText = SingleOrBatch; -pub type BootstrapHost = SingleOrBatch; -pub type BootstrapPort = SingleOrBatch>; -pub type BootstrapRoom = SingleOrBatch; - -#[typetag::serde(tag = "type")] -pub trait Bootstrap { - fn is_stream(&self) -> bool; - fn get_batch_size(&self) -> Result, actix_web::Error>; - fn set_bootstrap_info( - &mut self, - bootstrap_host: BootstrapHost, - bootstrap_port: BootstrapPort, - bootstrap_room: BootstrapRoom, - ); - - fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), actix_web::Error> { - let batch_size = self.get_batch_size()?; - if let Some(batch_size) = batch_size { - self.set_bootstrap_info( - BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]), - BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]), - BootstrapRoom::Batch((0..batch_size).map(|_| rand::random::()).collect()), - ); - } else { - self.set_bootstrap_info( - BootstrapHost::Single(prefill_info.get_hostname()), - BootstrapPort::Single(prefill_info.bootstrap_port), - BootstrapRoom::Single(rand::random::()), - ); - } - Ok(()) - } -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct GenerateReqInput { - pub text: Option, - pub input_ids: Option, - #[serde(default)] - pub stream: bool, - pub bootstrap_host: Option, - pub bootstrap_port: Option, - pub bootstrap_room: Option, - - #[serde(flatten)] - pub other: Value, -} - -impl GenerateReqInput { - pub fn get_batch_size(&self) -> Result, actix_web::Error> { - if self.text.is_some() && self.input_ids.is_some() { - return Err(actix_web::error::ErrorBadRequest( - "Both text and input_ids are present in the request".to_string(), - )); - } - if let Some(InputText::Batch(texts)) = &self.text { - return Ok(Some(texts.len())); - } - if let Some(InputIds::Batch(ids)) = &self.input_ids { - return Ok(Some(ids.len())); - } - Ok(None) - } -} - -#[typetag::serde] -impl Bootstrap for GenerateReqInput { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_batch_size(&self) -> Result, actix_web::Error> { - self.get_batch_size() - } - - fn set_bootstrap_info( - &mut self, - bootstrap_host: BootstrapHost, - bootstrap_port: BootstrapPort, - bootstrap_room: BootstrapRoom, - ) { - self.bootstrap_host = Some(bootstrap_host); - self.bootstrap_port = Some(bootstrap_port); - self.bootstrap_room = Some(bootstrap_room); - } -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct ChatReqInput { - #[serde(default)] - pub stream: bool, - pub bootstrap_host: Option, - pub bootstrap_port: Option, - pub bootstrap_room: Option, - - #[serde(flatten)] - pub other: Value, -} - -#[typetag::serde] -impl Bootstrap for ChatReqInput { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_batch_size(&self) -> Result, actix_web::Error> { - Ok(None) - } - - fn set_bootstrap_info( - &mut self, - bootstrap_host: BootstrapHost, - bootstrap_port: BootstrapPort, - bootstrap_room: BootstrapRoom, - ) { - self.bootstrap_host = Some(bootstrap_host); - self.bootstrap_port = Some(bootstrap_port); - self.bootstrap_room = Some(bootstrap_room); - } -} diff --git a/sgl-pdlb/src/lb_state.rs b/sgl-pdlb/src/lb_state.rs deleted file mode 100644 index 075ce43ef..000000000 --- a/sgl-pdlb/src/lb_state.rs +++ /dev/null @@ -1,175 +0,0 @@ -use crate::io_struct::Bootstrap; -use crate::strategy_lb::{EngineInfo, EngineLoad, EngineType, LBPolicy, StrategyLB}; -use actix_web::HttpResponse; -use bytes::Bytes; -use futures::{Stream, StreamExt, future::join_all}; -use reqwest::{Method, StatusCode}; -use std::pin::Pin; - -pub enum ProxyResponseBody { - Full(Bytes), - Stream(Pin> + Send>>), -} - -pub struct ProxyResponse { - pub status: StatusCode, - pub body: ProxyResponseBody, -} - -impl ProxyResponse { - pub fn to_json(&self) -> Result { - match &self.body { - ProxyResponseBody::Full(body) => Ok(serde_json::from_slice(&body)?), - ProxyResponseBody::Stream(_) => Err(actix_web::error::ErrorBadRequest( - "Stream response is not supported", - )), - } - } -} - -impl Into> for ProxyResponse { - fn into(self) -> Result { - let status = actix_web::http::StatusCode::from_u16(self.status.as_u16()).map_err(|e| { - actix_web::error::ErrorBadGateway(format!("Invalid status code: {}", e)) - })?; - match self.body { - ProxyResponseBody::Full(body) => Ok(HttpResponse::Ok().status(status).body(body)), - ProxyResponseBody::Stream(body) => Ok(HttpResponse::Ok() - .status(status) - .content_type("application/octet-stream") - .streaming(body)), - } - } -} - -#[derive(Debug, Clone)] -pub struct LBConfig { - pub host: String, - pub port: u16, - pub policy: String, - pub prefill_infos: Vec<(String, Option)>, - pub decode_infos: Vec, - pub log_interval: u64, - pub timeout: u64, -} - -#[derive(Debug, Clone)] -pub struct LBState { - pub strategy_lb: StrategyLB, - pub client: reqwest::Client, - pub log_interval: u64, -} - -impl LBState { - pub fn new(lb_config: LBConfig) -> anyhow::Result { - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(lb_config.timeout)) - .build()?; - let policy = match lb_config.policy.as_str() { - "random" => LBPolicy::Random, - "po2" => LBPolicy::PowerOfTwo, - _ => anyhow::bail!("Invalid policy"), - }; - let prefill_servers = lb_config - .prefill_infos - .into_iter() - .map(|(url, port)| EngineInfo::new_prefill(url, port)) - .collect(); - let decode_servers = lb_config - .decode_infos - .into_iter() - .map(|url| EngineInfo::new_decode(url)) - .collect(); - let lb = StrategyLB::new(policy, prefill_servers, decode_servers); - Ok(Self { - strategy_lb: lb, - client, - log_interval: lb_config.log_interval, - }) - } - - pub async fn route_one( - &self, - engine_info: &EngineInfo, - method: Method, - api_path: &str, - request: Option<&serde_json::Value>, - stream: bool, - ) -> Result { - let url = engine_info.api_path(api_path); - let request = request.unwrap_or(&serde_json::Value::Null); - let task = self.client.request(method, url).json(request).send(); - let resp = task.await.map_err(actix_web::error::ErrorBadGateway)?; - // FIXME: handle error status code (map status code to error) - let status = resp.status(); - let body = if stream { - let resp_stream = resp.bytes_stream().map(|r| { - r.map_err(actix_web::error::ErrorBadGateway) - .map(Bytes::from) - }); - ProxyResponseBody::Stream(Box::pin(resp_stream)) - } else { - let body = resp - .bytes() - .await - .map_err(actix_web::error::ErrorBadGateway)?; - ProxyResponseBody::Full(body) - }; - Ok(ProxyResponse { status, body }) - } - - pub async fn route_collect( - &self, - engines: &Vec, - method: Method, - api_path: &str, - request: Option<&serde_json::Value>, - ) -> Result, actix_web::Error> { - let tasks = engines - .iter() - .map(|engine| self.route_one(engine, method.clone(), api_path, request, false)); - let responses = join_all(tasks).await; - responses - .into_iter() - .map(|r| r.map_err(actix_web::error::ErrorBadGateway)) - .collect() - } - - pub async fn generate( - &self, - api_path: &str, - mut req: Box, - ) -> Result { - let (prefill, decode) = self.strategy_lb.select_pair(&self.client).await; - let stream = req.is_stream(); - req.add_bootstrap_info(&prefill)?; - let json = serde_json::to_value(req)?; - let prefill_task = self.route_one(&prefill, Method::POST, api_path, Some(&json), false); - let decode_task = self.route_one(&decode, Method::POST, api_path, Some(&json), stream); - let (_, decode_response) = tokio::join!(prefill_task, decode_task); - decode_response?.into() - } - - pub async fn get_engine_loads( - &self, - ) -> Result<(Vec, Vec), actix_web::Error> { - let servers = self.strategy_lb.get_all_servers(); - let responses = self - .route_collect(&servers, Method::GET, "/get_load", None) - .await?; - let loads = responses - .into_iter() - .enumerate() - .map(|(i, r)| Ok(EngineLoad::from_json(&servers[i], &r.to_json()?))) - .collect::, actix_web::Error>>()?; - let mut prefill_loads = Vec::new(); - let mut decode_loads = Vec::new(); - for load in loads { - match load.engine_info.engine_type { - EngineType::Prefill => prefill_loads.push(load), - EngineType::Decode => decode_loads.push(load), - } - } - Ok((prefill_loads, decode_loads)) - } -} diff --git a/sgl-pdlb/src/lib.rs b/sgl-pdlb/src/lib.rs deleted file mode 100644 index 097b86aca..000000000 --- a/sgl-pdlb/src/lib.rs +++ /dev/null @@ -1,68 +0,0 @@ -pub mod io_struct; -pub mod lb_state; -pub mod server; -pub mod strategy_lb; -use pyo3::{exceptions::PyRuntimeError, prelude::*}; - -use lb_state::{LBConfig, LBState}; -use server::{periodic_logging, startup}; -use tokio::signal; - -#[pyclass] -pub struct LoadBalancer { - lb_config: LBConfig, -} - -#[pymethods] -impl LoadBalancer { - #[new] - pub fn new( - host: String, - port: u16, - policy: String, - prefill_infos: Vec<(String, Option)>, - decode_infos: Vec, - log_interval: u64, - timeout: u64, - ) -> PyResult { - let lb_config = LBConfig { - host, - port, - policy, - prefill_infos, - decode_infos, - log_interval, - timeout, - }; - Ok(LoadBalancer { lb_config }) - } - - pub fn start(&self) -> PyResult<()> { - let lb_state = LBState::new(self.lb_config.clone()).map_err(|e| { - PyRuntimeError::new_err(format!("Failed to build load balancer: {}", e)) - })?; - - let ret: PyResult<()> = actix_web::rt::System::new().block_on(async move { - tokio::select! { - _ = periodic_logging(lb_state.clone()) => { - unreachable!() - } - res = startup(self.lb_config.clone(), lb_state) => { - res.map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - unreachable!() - } - _ = signal::ctrl_c() => { - println!("Received Ctrl+C, shutting down"); - std::process::exit(0); - } - } - }); - ret - } -} - -#[pymodule] -fn _rust(_py: Python, m: &Bound) -> PyResult<()> { - m.add_class::()?; - Ok(()) -} diff --git a/sgl-pdlb/src/main.rs b/sgl-pdlb/src/main.rs deleted file mode 100644 index 7125aa4ed..000000000 --- a/sgl-pdlb/src/main.rs +++ /dev/null @@ -1,46 +0,0 @@ -mod io_struct; -mod lb_state; -mod server; -mod strategy_lb; - -use lb_state::{LBConfig, LBState}; -use server::{periodic_logging, startup}; -use tokio::signal; - -fn main() -> anyhow::Result<()> { - // FIXME: test code, move to test folder - let prefill_infos = (0..8) - .map(|i| (format!("123.123.123.123:{}", i), None)) - .collect::)>>(); - - let decode_infos = (0..32) - .map(|i| format!("233.233.233.233:{}", i)) - .collect::>(); - - let lb_config = LBConfig { - host: "localhost".to_string(), - port: 8080, - policy: "random".to_string(), - prefill_infos, - decode_infos, - log_interval: 5, - timeout: 600, - }; - let lb_state = LBState::new(lb_config.clone()).map_err(|e| anyhow::anyhow!(e))?; - let ret: anyhow::Result<()> = actix_web::rt::System::new().block_on(async move { - tokio::select! { - _ = periodic_logging(lb_state.clone()) => { - unreachable!() - } - res = startup(lb_config.clone(), lb_state) => { - res.map_err(|e| anyhow::anyhow!(e))?; - unreachable!() - } - _ = signal::ctrl_c() => { - println!("Received Ctrl+C, shutting down"); - std::process::exit(0); - } - } - }); - ret -} diff --git a/sgl-pdlb/src/server.rs b/sgl-pdlb/src/server.rs deleted file mode 100644 index b763c743b..000000000 --- a/sgl-pdlb/src/server.rs +++ /dev/null @@ -1,183 +0,0 @@ -use crate::io_struct::{ChatReqInput, GenerateReqInput}; -use crate::lb_state::{LBConfig, LBState}; -use crate::strategy_lb::EngineType; -use actix_web::{HttpRequest, HttpResponse, HttpServer, get, post, web}; -use reqwest::Method; -use serde_json::json; -use std::io::Write; - -#[get("/health")] -pub async fn health(_req: HttpRequest, _: web::Data) -> HttpResponse { - HttpResponse::Ok().body("Ok") -} - -#[get("/health_generate")] -pub async fn health_generate( - _req: HttpRequest, - app_state: web::Data, -) -> Result { - let servers = app_state.strategy_lb.get_all_servers(); - app_state - .route_collect(&servers, Method::GET, "/health_generate", None) - .await?; - // FIXME: log the response - Ok(HttpResponse::Ok().body("Health check passed on all servers")) -} - -#[post("/flush_cache")] -pub async fn flush_cache( - _req: HttpRequest, - app_state: web::Data, -) -> Result { - let servers = app_state.strategy_lb.get_all_servers(); - app_state - .route_collect(&servers, Method::POST, "/flush_cache", None) - .await?; - Ok(HttpResponse::Ok().body("Cache flushed on all servers")) -} - -#[get("/get_model_info")] -pub async fn get_model_info( - _req: HttpRequest, - app_state: web::Data, -) -> Result { - // Return the first server's model info - let engine = app_state.strategy_lb.get_one_server(); - app_state - .route_one(&engine, Method::GET, "/get_model_info", None, false) - .await? - .into() -} - -#[post("/generate")] -pub async fn generate( - _req: HttpRequest, - req: web::Json, - app_state: web::Data, -) -> Result { - app_state - .generate("/generate", Box::new(req.into_inner())) - .await -} - -#[post("/v1/completions")] -pub async fn completions( - _req: HttpRequest, - req: web::Json, - app_state: web::Data, -) -> Result { - app_state - .generate("/v1/completions", Box::new(req.into_inner())) - .await -} - -#[post("/v1/chat/completions")] -pub async fn chat_completions( - _req: HttpRequest, - req: web::Json, - app_state: web::Data, -) -> Result { - app_state - .generate("/v1/chat/completions", Box::new(req.into_inner())) - .await -} - -#[get("/get_server_info")] -pub async fn get_server_info( - _req: HttpRequest, - app_state: web::Data, -) -> Result { - let servers = app_state.strategy_lb.get_all_servers(); - let responses = app_state - .route_collect(&servers, Method::GET, "/get_server_info", None) - .await?; - let mut prefill_infos = Vec::new(); - let mut decode_infos = Vec::new(); - for (i, resp) in responses.iter().enumerate() { - let json = resp.to_json()?; - match servers[i].engine_type { - EngineType::Prefill => prefill_infos.push(json), - EngineType::Decode => decode_infos.push(json), - } - } - Ok(HttpResponse::Ok().json(json!({ - "prefill": prefill_infos, - "decode": decode_infos, - }))) -} - -#[get("/get_loads")] -pub async fn get_loads( - _req: HttpRequest, - app_state: web::Data, -) -> Result { - let (prefill_loads, decode_loads) = app_state.get_engine_loads().await?; - Ok(HttpResponse::Ok().json(json!({ - "prefill": prefill_loads.into_iter().map(|l| l.to_json()).collect::>(), - "decode": decode_loads.into_iter().map(|l| l.to_json()).collect::>() - }))) -} - -pub async fn periodic_logging(lb_state: LBState) { - // FIXME: currently we can just clone the lb_state to log as the lb is stateless - loop { - tokio::time::sleep(std::time::Duration::from_secs(lb_state.log_interval)).await; - let (prefill_loads, decode_loads) = match lb_state.get_engine_loads().await { - Ok((prefill_loads, decode_loads)) => (prefill_loads, decode_loads), - Err(e) => { - log::error!("Failed to get engine loads: {}", e); - continue; - } - }; - let prefill_loads = prefill_loads - .into_iter() - .map(|l| l.to_string()) - .collect::>(); - let decode_loads = decode_loads - .into_iter() - .map(|l| l.to_string()) - .collect::>(); - log::info!("Prefill loads: {}", prefill_loads.join(", ")); - log::info!("Decode loads: {}", decode_loads.join(", ")); - } -} - -pub async fn startup(lb_config: LBConfig, lb_state: LBState) -> std::io::Result<()> { - let app_state = web::Data::new(lb_state); - - println!("Starting server at {}:{}", lb_config.host, lb_config.port); - - // default level is info - env_logger::Builder::new() - .format(|buf, record| { - writeln!( - buf, - "{} - {} - {}", - chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), - record.level(), - record.args() - ) - }) - .filter(None, log::LevelFilter::Info) - .init(); - - HttpServer::new(move || { - actix_web::App::new() - .wrap(actix_web::middleware::Logger::default()) - .app_data(app_state.clone()) - .service(health) - .service(health_generate) - .service(flush_cache) - .service(get_model_info) - .service(get_server_info) - .service(get_loads) - .service(generate) - .service(chat_completions) - .service(completions) - }) - .bind((lb_config.host, lb_config.port))? - .run() - .await?; - - std::io::Result::Ok(()) -} diff --git a/sgl-pdlb/src/strategy_lb.rs b/sgl-pdlb/src/strategy_lb.rs deleted file mode 100644 index aeb5aba26..000000000 --- a/sgl-pdlb/src/strategy_lb.rs +++ /dev/null @@ -1,182 +0,0 @@ -use rand::Rng; -use serde_json::json; - -#[derive(Debug, Clone)] -pub enum EngineType { - Prefill, - Decode, -} - -#[derive(Debug, Clone)] -pub struct EngineInfo { - pub engine_type: EngineType, - pub url: String, - pub bootstrap_port: Option, -} - -impl EngineInfo { - pub fn new_prefill(url: String, bootstrap_port: Option) -> Self { - EngineInfo { - engine_type: EngineType::Prefill, - url, - bootstrap_port, - } - } - - pub fn new_decode(url: String) -> Self { - EngineInfo { - engine_type: EngineType::Decode, - url, - bootstrap_port: None, - } - } - - pub fn api_path(&self, api_path: &str) -> String { - if api_path.starts_with("/") { - format!("{}{}", self.url, api_path) - } else { - format!("{}/{}", self.url, api_path) - } - } - - pub fn to_string(&self) -> String { - format!("({:?}@{})", self.engine_type, self.url) - } - - pub fn get_hostname(&self) -> String { - let url = self - .url - .trim_start_matches("http://") - .trim_start_matches("https://"); - url.split(':').next().unwrap().to_string() - } -} - -pub struct EngineLoad { - pub engine_info: EngineInfo, - pub load: isize, -} - -impl EngineLoad { - pub fn from_json(engine_info: &EngineInfo, json: &serde_json::Value) -> Self { - let load = match json.get("load") { - Some(load) => load.as_i64().unwrap_or(-1) as isize, - None => -1, - }; - EngineLoad { - engine_info: engine_info.clone(), - load, - } - } - pub fn to_json(&self) -> serde_json::Value { - json!({ - "engine": self.engine_info.to_string(), - "load": self.load, - }) - } - - pub fn to_string(&self) -> String { - format!("{}: {}", self.engine_info.to_string(), self.load) - } -} - -#[derive(Debug, Clone)] -pub enum LBPolicy { - Random, - PowerOfTwo, -} - -#[derive(Debug, Clone)] -pub struct StrategyLB { - pub policy: LBPolicy, - pub prefill_servers: Vec, - pub decode_servers: Vec, -} - -impl StrategyLB { - pub fn new( - policy: LBPolicy, - prefill_servers: Vec, - decode_servers: Vec, - ) -> Self { - StrategyLB { - policy, - prefill_servers, - decode_servers, - } - } - - pub fn get_one_server(&self) -> EngineInfo { - assert!(!self.prefill_servers.is_empty()); - assert!(!self.decode_servers.is_empty()); - self.prefill_servers[0].clone() - } - - pub fn get_all_servers(&self) -> Vec { - let mut all_servers = Vec::new(); - all_servers.extend(self.prefill_servers.clone()); - all_servers.extend(self.decode_servers.clone()); - all_servers - } - - pub async fn select_pair(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) { - match self.policy { - LBPolicy::Random => self.select_pd_pair_random(), - LBPolicy::PowerOfTwo => self.select_pd_pair_po2(client).await, - } - } - - fn select_pd_pair_random(&self) -> (EngineInfo, EngineInfo) { - let mut rng = rand::rng(); - let prefill_index = rng.random_range(0..self.prefill_servers.len()); - let decode_index = rng.random_range(0..self.decode_servers.len()); - - ( - self.prefill_servers[prefill_index].clone(), - self.decode_servers[decode_index].clone(), - ) - } - - async fn get_load_from_engine( - &self, - client: &reqwest::Client, - engine_info: &EngineInfo, - ) -> Option { - let url = engine_info.api_path("/get_load"); - let response = client.get(url).send().await.unwrap(); - match response.status() { - reqwest::StatusCode::OK => { - let data = response.json::().await.unwrap(); - Some(data["load"].as_i64().unwrap() as isize) - } - _ => None, - } - } - - async fn select_pd_pair_po2(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) { - let mut rng = rand::rng(); - let prefill1 = - self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone(); - let prefill2 = - self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone(); - let decode1 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone(); - let decode2 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone(); - let prefill1_load = self.get_load_from_engine(client, &prefill1).await; - let prefill2_load = self.get_load_from_engine(client, &prefill2).await; - let decode1_load = self.get_load_from_engine(client, &decode1).await; - let decode2_load = self.get_load_from_engine(client, &decode2).await; - - ( - if prefill1_load < prefill2_load { - prefill1 - } else { - prefill2 - }, - if decode1_load < decode2_load { - decode1 - } else { - decode2 - }, - ) - } -}