@dataclasses.dataclass
class LBArgs:
    """Settings for the PD disaggregation load balancer.

    Instances are normally built from parsed command-line arguments with
    :meth:`from_cli_args` after registering the options with
    :meth:`add_cli_args`.

    Raises:
        ValueError: If the Python load balancer is selected (``rust_lb`` is
            false) together with a policy other than ``"random"``.
    """

    # Use the Rust load balancer instead of the Python one.
    rust_lb: bool = False
    # Address and port the load balancer binds to.
    host: str = "0.0.0.0"
    port: int = 8000
    # Load-balancing policy name; the CLI accepts "random" and "po2".
    policy: str = "random"
    # One ``(url, bootstrap_port)`` tuple per prefill server.
    prefill_infos: list = dataclasses.field(default_factory=list)
    # URLs of the decode servers.
    decode_infos: list = dataclasses.field(default_factory=list)
    # Log interval in seconds.
    log_interval: int = 5
    # Timeout in seconds.
    timeout: int = 600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register every load-balancer option on ``parser``."""
        parser.add_argument(
            "--rust-lb",
            action="store_true",
            help="Use Rust load balancer",
        )
        parser.add_argument(
            "--host",
            type=str,
            default=LBArgs.host,
            help=f"Host to bind the server (default: {LBArgs.host})",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=LBArgs.port,
            help=f"Port to bind the server (default: {LBArgs.port})",
        )
        parser.add_argument(
            "--policy",
            type=str,
            default=LBArgs.policy,
            choices=["random", "po2"],
            help=f"Policy to use for load balancing (default: {LBArgs.policy})",
        )
        parser.add_argument(
            "--prefill",
            type=str,
            default=[],
            nargs="+",
            help="URLs for prefill servers",
        )
        parser.add_argument(
            "--decode",
            type=str,
            default=[],
            nargs="+",
            help="URLs for decode servers",
        )
        parser.add_argument(
            "--prefill-bootstrap-ports",
            type=int,
            nargs="+",
            help="Bootstrap ports for prefill servers",
        )
        parser.add_argument(
            "--log-interval",
            type=int,
            default=LBArgs.log_interval,
            help=f"Log interval in seconds (default: {LBArgs.log_interval})",
        )
        parser.add_argument(
            "--timeout",
            type=int,
            default=LBArgs.timeout,
            help=f"Timeout in seconds (default: {LBArgs.timeout})",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
        """Build an ``LBArgs`` from a namespace produced by ``add_cli_args``.

        Bootstrap ports are paired with prefill URLs as follows: no ports
        given -> every prefill server gets ``None``; exactly one port ->
        it is shared by all prefill servers; otherwise the counts must match.

        Raises:
            ValueError: If several bootstrap ports are given and their count
                differs from the number of prefill URLs.
        """
        # Copy the lists: argparse hands out the very same ``default=[]``
        # object on every parse, so storing it directly would let a later
        # mutation of this config leak into the parser defaults.
        prefill_urls = list(args.prefill)
        decode_urls = list(args.decode)

        bootstrap_ports = args.prefill_bootstrap_ports
        if bootstrap_ports is None:
            bootstrap_ports = [None] * len(prefill_urls)
        elif len(bootstrap_ports) == 1:
            bootstrap_ports = bootstrap_ports * len(prefill_urls)
        elif len(bootstrap_ports) != len(prefill_urls):
            raise ValueError(
                "Number of prefill URLs must match number of bootstrap ports"
            )

        return cls(
            rust_lb=args.rust_lb,
            host=args.host,
            port=args.port,
            policy=args.policy,
            prefill_infos=list(zip(prefill_urls, bootstrap_ports)),
            decode_infos=decode_urls,
            log_interval=args.log_interval,
            timeout=args.timeout,
        )

    def __post_init__(self):
        # Raise instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let an unsupported policy through.
        if not self.rust_lb and self.policy != "random":
            raise ValueError(
                "Only random policy is supported for Python load balancer"
            )
@app.get("/get_load")
async def get_load():
    """Handle ``GET /get_load`` by returning the tokenizer manager's load report."""
    load_report = await _global_state.tokenizer_manager.get_load()
    return load_report
def get_load(self):
    """Return this scheduler's load as a single token count.

    The value is the number of KV-cache tokens that are neither available
    nor evictable, plus the input-token count of every request in the
    waiting queue and, depending on the disaggregation mode, of every
    request in the prefill bootstrap queue or the decode prealloc queue.
    """
    # TODO(lsyin): use dynamically maintained num_waiting_tokens
    free_tokens = self.token_to_kv_pool_allocator.available_size()
    evictable_tokens = self.tree_cache.evictable_size()
    occupied_tokens = self.max_total_num_tokens - free_tokens - evictable_tokens

    queued_lengths = [len(req.origin_input_ids) for req in self.waiting_queue]

    mode = self.disaggregation_mode
    if mode == DisaggregationMode.PREFILL:
        queued_lengths.extend(
            len(req.origin_input_ids)
            for req in self.disagg_prefill_bootstrap_queue.queue
        )
    elif mode == DisaggregationMode.DECODE:
        # Entries of the decode prealloc queue wrap the request in ``.req``.
        queued_lengths.extend(
            len(entry.req.origin_input_ids)
            for entry in self.disagg_decode_prealloc_queue.queue
        )

    return occupied_tokens + sum(queued_lengths)
async def get_load(self) -> dict:
    """Return the most recent load as ``{"load": value}``.

    When no refresh is in flight, the internal state is fetched and the
    ``"load"`` entry of its first element is cached in ``current_load``.
    When ``current_load_lock`` is already held, the cached value is
    returned without waiting.
    """
    # TODO(lsyin): fake load report server
    refresh_lock = self.current_load_lock
    if refresh_lock.locked():
        # A refresh is already running; answer from the cache.
        return {"load": self.current_load}

    async with refresh_lock:
        states = await self.get_internal_state()
        self.current_load = states[0]["load"]
    return {"load": self.current_load}
diff --git a/sgl-pdlb/README.md b/sgl-pdlb/README.md new file mode 100644 index 000000000..c763ed501 --- /dev/null +++ b/sgl-pdlb/README.md @@ -0,0 +1,12 @@ +### Install dependencies + +```bash +pip install "maturin[patchelf]" +``` + +### Build and install + +```bash +maturin develop +pip install -e . +``` diff --git a/sgl-pdlb/py_src/sgl_pdlb/__init__.py b/sgl-pdlb/py_src/sgl_pdlb/__init__.py new file mode 100644 index 000000000..f102a9cad --- /dev/null +++ b/sgl-pdlb/py_src/sgl_pdlb/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/sgl-pdlb/pyproject.toml b/sgl-pdlb/pyproject.toml new file mode 100644 index 000000000..4a0f80b75 --- /dev/null +++ b/sgl-pdlb/pyproject.toml @@ -0,0 +1,14 @@ +[build-system] +requires = ["maturin>=1.8.0"] +build-backend = "maturin" + +[project] +name = "sgl_pdlb" +version = "0.0.1" + +[tool.maturin] +python-source = "py_src" +module-name = "sgl_pdlb._rust" + +[tool.maturin.build-backend] +features = ["pyo3/extension-module"] diff --git a/sgl-pdlb/src/io_struct.rs b/sgl-pdlb/src/io_struct.rs new file mode 100644 index 000000000..804aca58b --- /dev/null +++ b/sgl-pdlb/src/io_struct.rs @@ -0,0 +1,133 @@ +use crate::strategy_lb::EngineInfo; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum SingleOrBatch { + Single(T), + Batch(Vec), +} + +pub type InputIds = SingleOrBatch>; +pub type InputText = SingleOrBatch; +pub type BootstrapHost = SingleOrBatch; +pub type BootstrapPort = SingleOrBatch>; +pub type BootstrapRoom = SingleOrBatch; + +#[typetag::serde(tag = "type")] +pub trait Bootstrap { + fn is_stream(&self) -> bool; + fn get_batch_size(&self) -> Result, actix_web::Error>; + fn set_bootstrap_info( + &mut self, + bootstrap_host: BootstrapHost, + bootstrap_port: BootstrapPort, + bootstrap_room: BootstrapRoom, + ); + + fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), actix_web::Error> { + let batch_size = 
self.get_batch_size()?; + if let Some(batch_size) = batch_size { + self.set_bootstrap_info( + BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]), + BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]), + BootstrapRoom::Batch((0..batch_size).map(|_| rand::random::()).collect()), + ); + } else { + self.set_bootstrap_info( + BootstrapHost::Single(prefill_info.get_hostname()), + BootstrapPort::Single(prefill_info.bootstrap_port), + BootstrapRoom::Single(rand::random::()), + ); + } + Ok(()) + } +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct GenerateReqInput { + pub text: Option, + pub input_ids: Option, + #[serde(default)] + pub stream: bool, + pub bootstrap_host: Option, + pub bootstrap_port: Option, + pub bootstrap_room: Option, + + #[serde(flatten)] + pub other: Value, +} + +impl GenerateReqInput { + pub fn get_batch_size(&self) -> Result, actix_web::Error> { + if self.text.is_some() && self.input_ids.is_some() { + return Err(actix_web::error::ErrorBadRequest( + "Both text and input_ids are present in the request".to_string(), + )); + } + if let Some(InputText::Batch(texts)) = &self.text { + return Ok(Some(texts.len())); + } + if let Some(InputIds::Batch(ids)) = &self.input_ids { + return Ok(Some(ids.len())); + } + Ok(None) + } +} + +#[typetag::serde] +impl Bootstrap for GenerateReqInput { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_batch_size(&self) -> Result, actix_web::Error> { + self.get_batch_size() + } + + fn set_bootstrap_info( + &mut self, + bootstrap_host: BootstrapHost, + bootstrap_port: BootstrapPort, + bootstrap_room: BootstrapRoom, + ) { + self.bootstrap_host = Some(bootstrap_host); + self.bootstrap_port = Some(bootstrap_port); + self.bootstrap_room = Some(bootstrap_room); + } +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ChatReqInput { + #[serde(default)] + pub stream: bool, + pub bootstrap_host: Option, + pub bootstrap_port: Option, + pub bootstrap_room: Option, + + 
#[serde(flatten)] + pub other: Value, +} + +#[typetag::serde] +impl Bootstrap for ChatReqInput { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_batch_size(&self) -> Result, actix_web::Error> { + Ok(None) + } + + fn set_bootstrap_info( + &mut self, + bootstrap_host: BootstrapHost, + bootstrap_port: BootstrapPort, + bootstrap_room: BootstrapRoom, + ) { + self.bootstrap_host = Some(bootstrap_host); + self.bootstrap_port = Some(bootstrap_port); + self.bootstrap_room = Some(bootstrap_room); + } +} diff --git a/sgl-pdlb/src/lb_state.rs b/sgl-pdlb/src/lb_state.rs new file mode 100644 index 000000000..075ce43ef --- /dev/null +++ b/sgl-pdlb/src/lb_state.rs @@ -0,0 +1,175 @@ +use crate::io_struct::Bootstrap; +use crate::strategy_lb::{EngineInfo, EngineLoad, EngineType, LBPolicy, StrategyLB}; +use actix_web::HttpResponse; +use bytes::Bytes; +use futures::{Stream, StreamExt, future::join_all}; +use reqwest::{Method, StatusCode}; +use std::pin::Pin; + +pub enum ProxyResponseBody { + Full(Bytes), + Stream(Pin> + Send>>), +} + +pub struct ProxyResponse { + pub status: StatusCode, + pub body: ProxyResponseBody, +} + +impl ProxyResponse { + pub fn to_json(&self) -> Result { + match &self.body { + ProxyResponseBody::Full(body) => Ok(serde_json::from_slice(&body)?), + ProxyResponseBody::Stream(_) => Err(actix_web::error::ErrorBadRequest( + "Stream response is not supported", + )), + } + } +} + +impl Into> for ProxyResponse { + fn into(self) -> Result { + let status = actix_web::http::StatusCode::from_u16(self.status.as_u16()).map_err(|e| { + actix_web::error::ErrorBadGateway(format!("Invalid status code: {}", e)) + })?; + match self.body { + ProxyResponseBody::Full(body) => Ok(HttpResponse::Ok().status(status).body(body)), + ProxyResponseBody::Stream(body) => Ok(HttpResponse::Ok() + .status(status) + .content_type("application/octet-stream") + .streaming(body)), + } + } +} + +#[derive(Debug, Clone)] +pub struct LBConfig { + pub host: String, + pub port: u16, + pub 
policy: String, + pub prefill_infos: Vec<(String, Option)>, + pub decode_infos: Vec, + pub log_interval: u64, + pub timeout: u64, +} + +#[derive(Debug, Clone)] +pub struct LBState { + pub strategy_lb: StrategyLB, + pub client: reqwest::Client, + pub log_interval: u64, +} + +impl LBState { + pub fn new(lb_config: LBConfig) -> anyhow::Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(lb_config.timeout)) + .build()?; + let policy = match lb_config.policy.as_str() { + "random" => LBPolicy::Random, + "po2" => LBPolicy::PowerOfTwo, + _ => anyhow::bail!("Invalid policy"), + }; + let prefill_servers = lb_config + .prefill_infos + .into_iter() + .map(|(url, port)| EngineInfo::new_prefill(url, port)) + .collect(); + let decode_servers = lb_config + .decode_infos + .into_iter() + .map(|url| EngineInfo::new_decode(url)) + .collect(); + let lb = StrategyLB::new(policy, prefill_servers, decode_servers); + Ok(Self { + strategy_lb: lb, + client, + log_interval: lb_config.log_interval, + }) + } + + pub async fn route_one( + &self, + engine_info: &EngineInfo, + method: Method, + api_path: &str, + request: Option<&serde_json::Value>, + stream: bool, + ) -> Result { + let url = engine_info.api_path(api_path); + let request = request.unwrap_or(&serde_json::Value::Null); + let task = self.client.request(method, url).json(request).send(); + let resp = task.await.map_err(actix_web::error::ErrorBadGateway)?; + // FIXME: handle error status code (map status code to error) + let status = resp.status(); + let body = if stream { + let resp_stream = resp.bytes_stream().map(|r| { + r.map_err(actix_web::error::ErrorBadGateway) + .map(Bytes::from) + }); + ProxyResponseBody::Stream(Box::pin(resp_stream)) + } else { + let body = resp + .bytes() + .await + .map_err(actix_web::error::ErrorBadGateway)?; + ProxyResponseBody::Full(body) + }; + Ok(ProxyResponse { status, body }) + } + + pub async fn route_collect( + &self, + engines: &Vec, + method: Method, + 
api_path: &str, + request: Option<&serde_json::Value>, + ) -> Result, actix_web::Error> { + let tasks = engines + .iter() + .map(|engine| self.route_one(engine, method.clone(), api_path, request, false)); + let responses = join_all(tasks).await; + responses + .into_iter() + .map(|r| r.map_err(actix_web::error::ErrorBadGateway)) + .collect() + } + + pub async fn generate( + &self, + api_path: &str, + mut req: Box, + ) -> Result { + let (prefill, decode) = self.strategy_lb.select_pair(&self.client).await; + let stream = req.is_stream(); + req.add_bootstrap_info(&prefill)?; + let json = serde_json::to_value(req)?; + let prefill_task = self.route_one(&prefill, Method::POST, api_path, Some(&json), false); + let decode_task = self.route_one(&decode, Method::POST, api_path, Some(&json), stream); + let (_, decode_response) = tokio::join!(prefill_task, decode_task); + decode_response?.into() + } + + pub async fn get_engine_loads( + &self, + ) -> Result<(Vec, Vec), actix_web::Error> { + let servers = self.strategy_lb.get_all_servers(); + let responses = self + .route_collect(&servers, Method::GET, "/get_load", None) + .await?; + let loads = responses + .into_iter() + .enumerate() + .map(|(i, r)| Ok(EngineLoad::from_json(&servers[i], &r.to_json()?))) + .collect::, actix_web::Error>>()?; + let mut prefill_loads = Vec::new(); + let mut decode_loads = Vec::new(); + for load in loads { + match load.engine_info.engine_type { + EngineType::Prefill => prefill_loads.push(load), + EngineType::Decode => decode_loads.push(load), + } + } + Ok((prefill_loads, decode_loads)) + } +} diff --git a/sgl-pdlb/src/lib.rs b/sgl-pdlb/src/lib.rs new file mode 100644 index 000000000..097b86aca --- /dev/null +++ b/sgl-pdlb/src/lib.rs @@ -0,0 +1,68 @@ +pub mod io_struct; +pub mod lb_state; +pub mod server; +pub mod strategy_lb; +use pyo3::{exceptions::PyRuntimeError, prelude::*}; + +use lb_state::{LBConfig, LBState}; +use server::{periodic_logging, startup}; +use tokio::signal; + +#[pyclass] +pub 
struct LoadBalancer { + lb_config: LBConfig, +} + +#[pymethods] +impl LoadBalancer { + #[new] + pub fn new( + host: String, + port: u16, + policy: String, + prefill_infos: Vec<(String, Option)>, + decode_infos: Vec, + log_interval: u64, + timeout: u64, + ) -> PyResult { + let lb_config = LBConfig { + host, + port, + policy, + prefill_infos, + decode_infos, + log_interval, + timeout, + }; + Ok(LoadBalancer { lb_config }) + } + + pub fn start(&self) -> PyResult<()> { + let lb_state = LBState::new(self.lb_config.clone()).map_err(|e| { + PyRuntimeError::new_err(format!("Failed to build load balancer: {}", e)) + })?; + + let ret: PyResult<()> = actix_web::rt::System::new().block_on(async move { + tokio::select! { + _ = periodic_logging(lb_state.clone()) => { + unreachable!() + } + res = startup(self.lb_config.clone(), lb_state) => { + res.map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + unreachable!() + } + _ = signal::ctrl_c() => { + println!("Received Ctrl+C, shutting down"); + std::process::exit(0); + } + } + }); + ret + } +} + +#[pymodule] +fn _rust(_py: Python, m: &Bound) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} diff --git a/sgl-pdlb/src/main.rs b/sgl-pdlb/src/main.rs new file mode 100644 index 000000000..7125aa4ed --- /dev/null +++ b/sgl-pdlb/src/main.rs @@ -0,0 +1,46 @@ +mod io_struct; +mod lb_state; +mod server; +mod strategy_lb; + +use lb_state::{LBConfig, LBState}; +use server::{periodic_logging, startup}; +use tokio::signal; + +fn main() -> anyhow::Result<()> { + // FIXME: test code, move to test folder + let prefill_infos = (0..8) + .map(|i| (format!("123.123.123.123:{}", i), None)) + .collect::)>>(); + + let decode_infos = (0..32) + .map(|i| format!("233.233.233.233:{}", i)) + .collect::>(); + + let lb_config = LBConfig { + host: "localhost".to_string(), + port: 8080, + policy: "random".to_string(), + prefill_infos, + decode_infos, + log_interval: 5, + timeout: 600, + }; + let lb_state = LBState::new(lb_config.clone()).map_err(|e| 
anyhow::anyhow!(e))?; + let ret: anyhow::Result<()> = actix_web::rt::System::new().block_on(async move { + tokio::select! { + _ = periodic_logging(lb_state.clone()) => { + unreachable!() + } + res = startup(lb_config.clone(), lb_state) => { + res.map_err(|e| anyhow::anyhow!(e))?; + unreachable!() + } + _ = signal::ctrl_c() => { + println!("Received Ctrl+C, shutting down"); + std::process::exit(0); + } + } + }); + ret +} diff --git a/sgl-pdlb/src/server.rs b/sgl-pdlb/src/server.rs new file mode 100644 index 000000000..03af2694a --- /dev/null +++ b/sgl-pdlb/src/server.rs @@ -0,0 +1,171 @@ +use crate::io_struct::{ChatReqInput, GenerateReqInput}; +use crate::lb_state::{LBConfig, LBState}; +use crate::strategy_lb::EngineType; +use actix_web::{HttpRequest, HttpResponse, HttpServer, get, post, web}; +use reqwest::Method; +use serde_json::json; +use std::io::Write; + +#[get("/health")] +pub async fn health(_req: HttpRequest, _: web::Data) -> HttpResponse { + HttpResponse::Ok().body("Ok") +} + +#[get("/health_generate")] +pub async fn health_generate( + _req: HttpRequest, + app_state: web::Data, +) -> Result { + let servers = app_state.strategy_lb.get_all_servers(); + app_state + .route_collect(&servers, Method::GET, "/health_generate", None) + .await?; + // FIXME: log the response + Ok(HttpResponse::Ok().body("Health check passed on all servers")) +} + +#[post("/flush_cache")] +pub async fn flush_cache( + _req: HttpRequest, + app_state: web::Data, +) -> Result { + let servers = app_state.strategy_lb.get_all_servers(); + app_state + .route_collect(&servers, Method::POST, "/flush_cache", None) + .await?; + Ok(HttpResponse::Ok().body("Cache flushed on all servers")) +} + +#[get("/get_model_info")] +pub async fn get_model_info( + _req: HttpRequest, + app_state: web::Data, +) -> Result { + // Return the first server's model info + let engine = app_state.strategy_lb.get_one_server(); + app_state + .route_one(&engine, Method::GET, "/get_model_info", None, false) + .await? 
+ .into() +} + +#[post("/generate")] +pub async fn generate( + _req: HttpRequest, + req: web::Json, + app_state: web::Data, +) -> Result { + app_state + .generate("/generate", Box::new(req.into_inner())) + .await +} + +#[post("/v1/chat/completions")] +pub async fn chat_completions( + _req: HttpRequest, + req: web::Json, + app_state: web::Data, +) -> Result { + app_state + .generate("/v1/chat/completions", Box::new(req.into_inner())) + .await +} + +#[get("/get_server_info")] +pub async fn get_server_info( + _req: HttpRequest, + app_state: web::Data, +) -> Result { + let servers = app_state.strategy_lb.get_all_servers(); + let responses = app_state + .route_collect(&servers, Method::GET, "/get_server_info", None) + .await?; + let mut prefill_infos = Vec::new(); + let mut decode_infos = Vec::new(); + for (i, resp) in responses.iter().enumerate() { + let json = resp.to_json()?; + match servers[i].engine_type { + EngineType::Prefill => prefill_infos.push(json), + EngineType::Decode => decode_infos.push(json), + } + } + Ok(HttpResponse::Ok().json(json!({ + "prefill": prefill_infos, + "decode": decode_infos, + }))) +} + +#[get("/get_loads")] +pub async fn get_loads( + _req: HttpRequest, + app_state: web::Data, +) -> Result { + let (prefill_loads, decode_loads) = app_state.get_engine_loads().await?; + Ok(HttpResponse::Ok().json(json!({ + "prefill": prefill_loads.into_iter().map(|l| l.to_json()).collect::>(), + "decode": decode_loads.into_iter().map(|l| l.to_json()).collect::>() + }))) +} + +pub async fn periodic_logging(lb_state: LBState) { + // FIXME: currently we can just clone the lb_state to log as the lb is stateless + loop { + tokio::time::sleep(std::time::Duration::from_secs(lb_state.log_interval)).await; + let (prefill_loads, decode_loads) = match lb_state.get_engine_loads().await { + Ok((prefill_loads, decode_loads)) => (prefill_loads, decode_loads), + Err(e) => { + log::error!("Failed to get engine loads: {}", e); + continue; + } + }; + let prefill_loads = 
prefill_loads + .into_iter() + .map(|l| l.to_string()) + .collect::>(); + let decode_loads = decode_loads + .into_iter() + .map(|l| l.to_string()) + .collect::>(); + log::info!("Prefill loads: {}", prefill_loads.join(", ")); + log::info!("Decode loads: {}", decode_loads.join(", ")); + } +} + +pub async fn startup(lb_config: LBConfig, lb_state: LBState) -> std::io::Result<()> { + let app_state = web::Data::new(lb_state); + + println!("Starting server at {}:{}", lb_config.host, lb_config.port); + + // default level is info + env_logger::Builder::new() + .format(|buf, record| { + writeln!( + buf, + "{} - {} - {}", + chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), + record.level(), + record.args() + ) + }) + .filter(None, log::LevelFilter::Info) + .init(); + + HttpServer::new(move || { + actix_web::App::new() + .wrap(actix_web::middleware::Logger::default()) + .app_data(app_state.clone()) + .service(health) + .service(health_generate) + .service(flush_cache) + .service(get_model_info) + .service(get_server_info) + .service(get_loads) + .service(generate) + .service(chat_completions) + }) + .bind((lb_config.host, lb_config.port))? 
+ .run() + .await?; + + std::io::Result::Ok(()) +} diff --git a/sgl-pdlb/src/strategy_lb.rs b/sgl-pdlb/src/strategy_lb.rs new file mode 100644 index 000000000..aeb5aba26 --- /dev/null +++ b/sgl-pdlb/src/strategy_lb.rs @@ -0,0 +1,182 @@ +use rand::Rng; +use serde_json::json; + +#[derive(Debug, Clone)] +pub enum EngineType { + Prefill, + Decode, +} + +#[derive(Debug, Clone)] +pub struct EngineInfo { + pub engine_type: EngineType, + pub url: String, + pub bootstrap_port: Option, +} + +impl EngineInfo { + pub fn new_prefill(url: String, bootstrap_port: Option) -> Self { + EngineInfo { + engine_type: EngineType::Prefill, + url, + bootstrap_port, + } + } + + pub fn new_decode(url: String) -> Self { + EngineInfo { + engine_type: EngineType::Decode, + url, + bootstrap_port: None, + } + } + + pub fn api_path(&self, api_path: &str) -> String { + if api_path.starts_with("/") { + format!("{}{}", self.url, api_path) + } else { + format!("{}/{}", self.url, api_path) + } + } + + pub fn to_string(&self) -> String { + format!("({:?}@{})", self.engine_type, self.url) + } + + pub fn get_hostname(&self) -> String { + let url = self + .url + .trim_start_matches("http://") + .trim_start_matches("https://"); + url.split(':').next().unwrap().to_string() + } +} + +pub struct EngineLoad { + pub engine_info: EngineInfo, + pub load: isize, +} + +impl EngineLoad { + pub fn from_json(engine_info: &EngineInfo, json: &serde_json::Value) -> Self { + let load = match json.get("load") { + Some(load) => load.as_i64().unwrap_or(-1) as isize, + None => -1, + }; + EngineLoad { + engine_info: engine_info.clone(), + load, + } + } + pub fn to_json(&self) -> serde_json::Value { + json!({ + "engine": self.engine_info.to_string(), + "load": self.load, + }) + } + + pub fn to_string(&self) -> String { + format!("{}: {}", self.engine_info.to_string(), self.load) + } +} + +#[derive(Debug, Clone)] +pub enum LBPolicy { + Random, + PowerOfTwo, +} + +#[derive(Debug, Clone)] +pub struct StrategyLB { + pub policy: 
LBPolicy, + pub prefill_servers: Vec, + pub decode_servers: Vec, +} + +impl StrategyLB { + pub fn new( + policy: LBPolicy, + prefill_servers: Vec, + decode_servers: Vec, + ) -> Self { + StrategyLB { + policy, + prefill_servers, + decode_servers, + } + } + + pub fn get_one_server(&self) -> EngineInfo { + assert!(!self.prefill_servers.is_empty()); + assert!(!self.decode_servers.is_empty()); + self.prefill_servers[0].clone() + } + + pub fn get_all_servers(&self) -> Vec { + let mut all_servers = Vec::new(); + all_servers.extend(self.prefill_servers.clone()); + all_servers.extend(self.decode_servers.clone()); + all_servers + } + + pub async fn select_pair(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) { + match self.policy { + LBPolicy::Random => self.select_pd_pair_random(), + LBPolicy::PowerOfTwo => self.select_pd_pair_po2(client).await, + } + } + + fn select_pd_pair_random(&self) -> (EngineInfo, EngineInfo) { + let mut rng = rand::rng(); + let prefill_index = rng.random_range(0..self.prefill_servers.len()); + let decode_index = rng.random_range(0..self.decode_servers.len()); + + ( + self.prefill_servers[prefill_index].clone(), + self.decode_servers[decode_index].clone(), + ) + } + + async fn get_load_from_engine( + &self, + client: &reqwest::Client, + engine_info: &EngineInfo, + ) -> Option { + let url = engine_info.api_path("/get_load"); + let response = client.get(url).send().await.unwrap(); + match response.status() { + reqwest::StatusCode::OK => { + let data = response.json::().await.unwrap(); + Some(data["load"].as_i64().unwrap() as isize) + } + _ => None, + } + } + + async fn select_pd_pair_po2(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) { + let mut rng = rand::rng(); + let prefill1 = + self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone(); + let prefill2 = + self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone(); + let decode1 = 
self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone(); + let decode2 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone(); + let prefill1_load = self.get_load_from_engine(client, &prefill1).await; + let prefill2_load = self.get_load_from_engine(client, &prefill2).await; + let decode1_load = self.get_load_from_engine(client, &decode1).await; + let decode2_load = self.get_load_from_engine(client, &decode2).await; + + ( + if prefill1_load < prefill2_load { + prefill1 + } else { + prefill2 + }, + if decode1_load < decode2_load { + decode1 + } else { + decode2 + }, + ) + } +}