PD Rust LB (PO2) (#6437)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
140
python/sglang/srt/disaggregation/launch_lb.py
Normal file
140
python/sglang/srt/disaggregation/launch_lb.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
class LBArgs:
    """Configuration for the PD-disaggregation load-balancer entry point.

    Defaults double as the argparse defaults in :meth:`add_cli_args`.
    """

    rust_lb: bool = False
    host: str = "0.0.0.0"
    port: int = 8000
    policy: str = "random"
    prefill_infos: list = dataclasses.field(default_factory=list)
    decode_infos: list = dataclasses.field(default_factory=list)
    log_interval: int = 5
    timeout: int = 600

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register the load-balancer CLI flags on *parser*."""
        add = parser.add_argument
        add(
            "--rust-lb",
            action="store_true",
            help="Use Rust load balancer",
        )
        add(
            "--host",
            type=str,
            default=LBArgs.host,
            help=f"Host to bind the server (default: {LBArgs.host})",
        )
        add(
            "--port",
            type=int,
            default=LBArgs.port,
            help=f"Port to bind the server (default: {LBArgs.port})",
        )
        add(
            "--policy",
            type=str,
            default=LBArgs.policy,
            choices=["random", "po2"],
            help=f"Policy to use for load balancing (default: {LBArgs.policy})",
        )
        add(
            "--prefill",
            type=str,
            default=[],
            nargs="+",
            help="URLs for prefill servers",
        )
        add(
            "--decode",
            type=str,
            default=[],
            nargs="+",
            help="URLs for decode servers",
        )
        add(
            "--prefill-bootstrap-ports",
            type=int,
            nargs="+",
            help="Bootstrap ports for prefill servers",
        )
        add(
            "--log-interval",
            type=int,
            default=LBArgs.log_interval,
            help=f"Log interval in seconds (default: {LBArgs.log_interval})",
        )
        add(
            "--timeout",
            type=int,
            default=LBArgs.timeout,
            help=f"Timeout in seconds (default: {LBArgs.timeout})",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
        """Build an ``LBArgs`` from parsed CLI args, pairing each prefill URL
        with its bootstrap port.

        Raises:
            ValueError: if more than one bootstrap port is given and the
                count does not match the number of prefill URLs.
        """
        ports = args.prefill_bootstrap_ports
        num_prefill = len(args.prefill)
        if ports is None:
            # No ports supplied: pair every prefill URL with None.
            ports = [None] * num_prefill
        elif len(ports) == 1:
            # A single port is broadcast to all prefill servers.
            ports = ports * num_prefill
        elif len(ports) != num_prefill:
            raise ValueError(
                "Number of prefill URLs must match number of bootstrap ports"
            )

        return cls(
            rust_lb=args.rust_lb,
            host=args.host,
            port=args.port,
            policy=args.policy,
            prefill_infos=list(zip(args.prefill, ports)),
            decode_infos=args.decode,
            log_interval=args.log_interval,
            timeout=args.timeout,
        )

    def __post_init__(self):
        # The Python mini LB only implements random selection; po2 requires
        # the Rust implementation.
        if not self.rust_lb:
            assert (
                self.policy == "random"
            ), "Only random policy is supported for Python load balancer"
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and launch the selected load balancer."""
    parser = argparse.ArgumentParser(
        description="PD Disaggregation Load Balancer Server"
    )
    LBArgs.add_cli_args(parser)
    lb_args = LBArgs.from_cli_args(parser.parse_args())

    if not lb_args.rust_lb:
        # Pure-Python fallback LB (random policy only, enforced by LBArgs).
        from sglang.srt.disaggregation.mini_lb import PrefillConfig, run

        prefill_configs = [
            PrefillConfig(url, port) for url, port in lb_args.prefill_infos
        ]
        run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
        return

    # Rust LB supports both "random" and "po2" policies; blocks until exit.
    from sgl_pdlb._rust import LoadBalancer as RustLB

    RustLB(
        host=lb_args.host,
        port=lb_args.port,
        policy=lb_args.policy,
        prefill_infos=lb_args.prefill_infos,
        decode_infos=lb_args.decode_infos,
        log_interval=lb_args.log_interval,
        timeout=lb_args.timeout,
    ).start()


if __name__ == "__main__":
    main()
|
||||
@@ -377,42 +377,7 @@ def run(prefill_configs, decode_addrs, host, port):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
|
||||
from sglang.srt.disaggregation.launch_lb import main
|
||||
|
||||
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
|
||||
parser.add_argument(
|
||||
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-bootstrap-ports",
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="Bootstrap ports for prefill servers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
bootstrap_ports = args.prefill_bootstrap_ports
|
||||
if bootstrap_ports is None:
|
||||
bootstrap_ports = [None] * len(args.prefill)
|
||||
elif len(bootstrap_ports) == 1:
|
||||
bootstrap_ports = bootstrap_ports * len(args.prefill)
|
||||
else:
|
||||
if len(bootstrap_ports) != len(args.prefill):
|
||||
raise ValueError(
|
||||
"Number of prefill URLs must match number of bootstrap ports"
|
||||
)
|
||||
|
||||
prefill_configs = [
|
||||
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
|
||||
]
|
||||
|
||||
run(prefill_configs, args.decode, args.host, args.port)
|
||||
main()
|
||||
|
||||
@@ -229,6 +229,11 @@ async def get_server_info():
|
||||
}
|
||||
|
||||
|
||||
@app.get("/get_load")
async def get_load():
    # Expose the current engine load (token-based) for external load
    # balancers; delegates to the tokenizer manager.
    return await _global_state.tokenizer_manager.get_load()
|
||||
|
||||
|
||||
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
|
||||
async def set_internal_state(obj: SetInternalStateReq, request: Request):
|
||||
res = await _global_state.tokenizer_manager.set_internal_state(obj)
|
||||
|
||||
@@ -103,7 +103,7 @@ class GenerateReqInput:
|
||||
|
||||
# For disaggregated inference
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
bootstrap_port: Optional[Union[List[int], int]] = None
|
||||
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
|
||||
@@ -1911,6 +1911,27 @@ class Scheduler(
|
||||
if_success = False
|
||||
return if_success
|
||||
|
||||
    def get_load(self):
        """Return the scheduler's current load, measured in tokens."""
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        # Occupied KV-cache tokens: capacity minus free and evictable slots.
        load = (
            self.max_total_num_tokens
            - self.token_to_kv_pool_allocator.available_size()
            - self.tree_cache.evictable_size()
        )
        # Plus input tokens of requests queued but not yet scheduled.
        load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Requests still waiting in the disaggregation bootstrap queue.
            load += sum(
                len(req.origin_input_ids)
                for req in self.disagg_prefill_bootstrap_queue.queue
            )
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Prealloc queue entries wrap the request — note the .req hop.
            load += sum(
                len(req.req.origin_input_ids)
                for req in self.disagg_decode_prealloc_queue.queue
            )

        return load
|
||||
|
||||
def get_internal_state(self, recv_req: GetInternalStateReq):
|
||||
ret = dict(global_server_args_dict)
|
||||
ret["last_gen_throughput"] = self.last_gen_throughput
|
||||
@@ -1920,9 +1941,10 @@ class Scheduler(
|
||||
)
|
||||
if RECORD_STEP_TIME:
|
||||
ret["step_time_dict"] = self.step_time_dict
|
||||
return GetInternalStateReqOutput(
|
||||
internal_state=ret,
|
||||
)
|
||||
|
||||
ret["load"] = self.get_load()
|
||||
|
||||
return GetInternalStateReqOutput(internal_state=ret)
|
||||
|
||||
def set_internal_state(self, recv_req: SetInternalStateReq):
|
||||
server_args_dict = recv_req.server_args
|
||||
|
||||
@@ -395,6 +395,9 @@ class TokenizerManager:
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
|
||||
self.current_load = 0
|
||||
self.current_load_lock = asyncio.Lock()
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
@@ -983,6 +986,14 @@ class TokenizerManager:
|
||||
# Many DP ranks
|
||||
return [res.internal_state for res in responses]
|
||||
|
||||
    async def get_load(self) -> dict:
        """Return ``{"load": int}`` for the load-report endpoint.

        If a refresh is already in flight, the cached value is served
        instead of waiting — this endpoint is polled by the load balancer.
        """
        # TODO(lsyin): fake load report server
        if not self.current_load_lock.locked():
            async with self.current_load_lock:
                internal_state = await self.get_internal_state()
                # get_internal_state returns a list (one entry per DP rank);
                # only rank 0's "load" is reported.  NOTE(review): confirm
                # whether aggregating across ranks is intended here.
                self.current_load = internal_state[0]["load"]
        return {"load": self.current_load}
|
||||
|
||||
async def set_internal_state(
|
||||
self, obj: SetInternalStateReq
|
||||
) -> SetInternalStateReqOutput:
|
||||
|
||||
2
sgl-pdlb/.rustfmt.toml
Normal file
2
sgl-pdlb/.rustfmt.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
reorder_imports = true
|
||||
reorder_modules = true
|
||||
28
sgl-pdlb/Cargo.toml
Normal file
28
sgl-pdlb/Cargo.toml
Normal file
@@ -0,0 +1,28 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "sgl-pdlb"
|
||||
version = "0.1.0"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
name = "sgl_pdlb_rs"
|
||||
|
||||
[dependencies]
|
||||
actix-web = "4.11"
|
||||
bytes = "1.8.0"
|
||||
chrono = "0.4.38"
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
dashmap = "6.1.0"
|
||||
env_logger = "0.11.5"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
http = "1.3.1"
|
||||
log = "0.4.22"
|
||||
pyo3 = { version = "0.25.0", features = ["extension-module"] }
|
||||
rand = "0.9.0"
|
||||
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1.34", features = ["full"] }
|
||||
anyhow = "1.0.98"
|
||||
typetag = "0.2.20"
|
||||
12
sgl-pdlb/README.md
Normal file
12
sgl-pdlb/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
### Install dependencies
|
||||
|
||||
```bash
|
||||
pip install "maturin[patchelf]"
|
||||
```
|
||||
|
||||
### Build and install
|
||||
|
||||
```bash
|
||||
maturin develop
|
||||
pip install -e .
|
||||
```
|
||||
1
sgl-pdlb/py_src/sgl_pdlb/__init__.py
Normal file
1
sgl-pdlb/py_src/sgl_pdlb/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Package version for sgl_pdlb; keep in sync with pyproject.toml.
__version__ = "0.0.1"
|
||||
14
sgl-pdlb/pyproject.toml
Normal file
14
sgl-pdlb/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[build-system]
|
||||
requires = ["maturin>=1.8.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "sgl_pdlb"
|
||||
version = "0.0.1"
|
||||
|
||||
[tool.maturin]
|
||||
python-source = "py_src"
|
||||
module-name = "sgl_pdlb._rust"
|
||||
|
||||
[tool.maturin.build-backend]
|
||||
features = ["pyo3/extension-module"]
|
||||
133
sgl-pdlb/src/io_struct.rs
Normal file
133
sgl-pdlb/src/io_struct.rs
Normal file
@@ -0,0 +1,133 @@
|
||||
use crate::strategy_lb::EngineInfo;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// A request field that may arrive as a single value or as a batch.
///
/// `#[serde(untagged)]` lets serde pick the variant from the JSON shape
/// (scalar vs. array) without an explicit tag field.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
    Single(T),
    Batch(Vec<T>),
}
|
||||
|
||||
// Request fields that may be single-valued or batched.
pub type InputIds = SingleOrBatch<Vec<i32>>;
pub type InputText = SingleOrBatch<String>;
pub type BootstrapHost = SingleOrBatch<String>;
// Port is optional per entry: a prefill server may have no bootstrap port.
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
pub type BootstrapRoom = SingleOrBatch<u64>;
|
||||
|
||||
/// Requests that can carry PD-disaggregation bootstrap metadata.
///
/// `typetag` keeps trait objects (de)serializable through serde.
#[typetag::serde(tag = "type")]
pub trait Bootstrap {
    /// Whether the client asked for a streaming response.
    fn is_stream(&self) -> bool;
    /// Batch size of the request, or `Ok(None)` for a single request.
    fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error>;
    /// Store the bootstrap host/port/room triple on the request.
    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    );

    /// Fill in bootstrap info pointing at the chosen prefill server,
    /// generating one random room id per request in the batch.
    fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), actix_web::Error> {
        let batch_size = self.get_batch_size()?;
        if let Some(batch_size) = batch_size {
            // Batched request: replicate host/port per entry; fresh room each.
            self.set_bootstrap_info(
                BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
                BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
                BootstrapRoom::Batch((0..batch_size).map(|_| rand::random::<u64>()).collect()),
            );
        } else {
            self.set_bootstrap_info(
                BootstrapHost::Single(prefill_info.get_hostname()),
                BootstrapPort::Single(prefill_info.bootstrap_port),
                BootstrapRoom::Single(rand::random::<u64>()),
            );
        }
        Ok(())
    }
}
|
||||
|
||||
/// Body of `/generate`: prompt text or token ids plus bootstrap fields.
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
    pub text: Option<InputText>,
    pub input_ids: Option<InputIds>,
    #[serde(default)]
    pub stream: bool,
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,

    // All remaining JSON fields are captured and forwarded untouched.
    #[serde(flatten)]
    pub other: Value,
}
|
||||
|
||||
impl GenerateReqInput {
    /// Batch size from whichever of `text` / `input_ids` is batched.
    ///
    /// Errors (BadRequest) if both are present; `Ok(None)` for single
    /// (non-batched) requests.
    pub fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
        if self.text.is_some() && self.input_ids.is_some() {
            return Err(actix_web::error::ErrorBadRequest(
                "Both text and input_ids are present in the request".to_string(),
            ));
        }
        if let Some(InputText::Batch(texts)) = &self.text {
            return Ok(Some(texts.len()));
        }
        if let Some(InputIds::Batch(ids)) = &self.input_ids {
            return Ok(Some(ids.len()));
        }
        Ok(None)
    }
}
|
||||
|
||||
#[typetag::serde]
impl Bootstrap for GenerateReqInput {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
        // Delegates to the inherent method (inherent methods take priority
        // over trait methods, so this is not a recursive call).
        self.get_batch_size()
    }

    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        self.bootstrap_host = Some(bootstrap_host);
        self.bootstrap_port = Some(bootstrap_port);
        self.bootstrap_room = Some(bootstrap_room);
    }
}
|
||||
|
||||
/// Body of `/v1/chat/completions`: only the fields the LB needs are typed;
/// everything else is passed through via `other`.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
    #[serde(default)]
    pub stream: bool,
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,

    // All remaining JSON fields are captured and forwarded untouched.
    #[serde(flatten)]
    pub other: Value,
}
|
||||
|
||||
#[typetag::serde]
impl Bootstrap for ChatReqInput {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
        // Chat requests are never batched at this layer.
        Ok(None)
    }

    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        self.bootstrap_host = Some(bootstrap_host);
        self.bootstrap_port = Some(bootstrap_port);
        self.bootstrap_room = Some(bootstrap_room);
    }
}
|
||||
175
sgl-pdlb/src/lb_state.rs
Normal file
175
sgl-pdlb/src/lb_state.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
use crate::io_struct::Bootstrap;
|
||||
use crate::strategy_lb::{EngineInfo, EngineLoad, EngineType, LBPolicy, StrategyLB};
|
||||
use actix_web::HttpResponse;
|
||||
use bytes::Bytes;
|
||||
use futures::{Stream, StreamExt, future::join_all};
|
||||
use reqwest::{Method, StatusCode};
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Body of a proxied upstream response: fully buffered, or streamed
/// chunk-by-chunk for streaming generations.
pub enum ProxyResponseBody {
    Full(Bytes),
    Stream(Pin<Box<dyn Stream<Item = Result<Bytes, actix_web::Error>> + Send>>),
}
|
||||
|
||||
/// Status + body captured from an upstream engine, before being converted
/// into an actix `HttpResponse`.
pub struct ProxyResponse {
    pub status: StatusCode,
    pub body: ProxyResponseBody,
}
|
||||
|
||||
impl ProxyResponse {
    /// Parse a fully-buffered body as JSON.
    ///
    /// Returns a BadRequest error for streaming bodies, which cannot be
    /// parsed without draining the stream.
    pub fn to_json(&self) -> Result<serde_json::Value, actix_web::Error> {
        match &self.body {
            ProxyResponseBody::Full(body) => Ok(serde_json::from_slice(&body)?),
            ProxyResponseBody::Stream(_) => Err(actix_web::error::ErrorBadRequest(
                "Stream response is not supported",
            )),
        }
    }
}
|
||||
|
||||
impl Into<Result<HttpResponse, actix_web::Error>> for ProxyResponse {
|
||||
fn into(self) -> Result<HttpResponse, actix_web::Error> {
|
||||
let status = actix_web::http::StatusCode::from_u16(self.status.as_u16()).map_err(|e| {
|
||||
actix_web::error::ErrorBadGateway(format!("Invalid status code: {}", e))
|
||||
})?;
|
||||
match self.body {
|
||||
ProxyResponseBody::Full(body) => Ok(HttpResponse::Ok().status(status).body(body)),
|
||||
ProxyResponseBody::Stream(body) => Ok(HttpResponse::Ok()
|
||||
.status(status)
|
||||
.content_type("application/octet-stream")
|
||||
.streaming(body)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// User-facing LB configuration, mirrored by the Python `LBArgs` dataclass.
#[derive(Debug, Clone)]
pub struct LBConfig {
    pub host: String,
    pub port: u16,
    /// "random" or "po2"; validated in `LBState::new`.
    pub policy: String,
    /// Prefill servers as (url, optional bootstrap port) pairs.
    pub prefill_infos: Vec<(String, Option<u16>)>,
    pub decode_infos: Vec<String>,
    /// Seconds between periodic load-log reports.
    pub log_interval: u64,
    /// Per-request timeout (seconds) for the shared HTTP client.
    pub timeout: u64,
}
|
||||
|
||||
/// Runtime state shared by all request handlers.
///
/// Cloneable because the LB itself is stateless beyond config + client.
#[derive(Debug, Clone)]
pub struct LBState {
    pub strategy_lb: StrategyLB,
    pub client: reqwest::Client,
    pub log_interval: u64,
}
|
||||
|
||||
impl LBState {
    /// Build runtime state from `lb_config`; fails on an unknown policy
    /// string or if the HTTP client cannot be constructed.
    pub fn new(lb_config: LBConfig) -> anyhow::Result<Self> {
        // One shared client; `timeout` bounds every proxied request.
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(lb_config.timeout))
            .build()?;
        let policy = match lb_config.policy.as_str() {
            "random" => LBPolicy::Random,
            "po2" => LBPolicy::PowerOfTwo,
            _ => anyhow::bail!("Invalid policy"),
        };
        let prefill_servers = lb_config
            .prefill_infos
            .into_iter()
            .map(|(url, port)| EngineInfo::new_prefill(url, port))
            .collect();
        let decode_servers = lb_config
            .decode_infos
            .into_iter()
            .map(|url| EngineInfo::new_decode(url))
            .collect();
        let lb = StrategyLB::new(policy, prefill_servers, decode_servers);
        Ok(Self {
            strategy_lb: lb,
            client,
            log_interval: lb_config.log_interval,
        })
    }

    /// Proxy one request to `engine_info`, returning a streaming body when
    /// `stream` is true, otherwise the fully-buffered body.
    pub async fn route_one(
        &self,
        engine_info: &EngineInfo,
        method: Method,
        api_path: &str,
        request: Option<&serde_json::Value>,
        stream: bool,
    ) -> Result<ProxyResponse, actix_web::Error> {
        let url = engine_info.api_path(api_path);
        // A missing payload is sent as JSON `null`.
        let request = request.unwrap_or(&serde_json::Value::Null);
        let task = self.client.request(method, url).json(request).send();
        let resp = task.await.map_err(actix_web::error::ErrorBadGateway)?;
        // FIXME: handle error status code (map status code to error)
        let status = resp.status();
        let body = if stream {
            let resp_stream = resp.bytes_stream().map(|r| {
                r.map_err(actix_web::error::ErrorBadGateway)
                    .map(Bytes::from)
            });
            ProxyResponseBody::Stream(Box::pin(resp_stream))
        } else {
            let body = resp
                .bytes()
                .await
                .map_err(actix_web::error::ErrorBadGateway)?;
            ProxyResponseBody::Full(body)
        };
        Ok(ProxyResponse { status, body })
    }

    /// Fan the same non-streaming request out to every engine concurrently
    /// and collect the responses; fails if any single request failed.
    pub async fn route_collect(
        &self,
        engines: &Vec<EngineInfo>,
        method: Method,
        api_path: &str,
        request: Option<&serde_json::Value>,
    ) -> Result<Vec<ProxyResponse>, actix_web::Error> {
        let tasks = engines
            .iter()
            .map(|engine| self.route_one(engine, method.clone(), api_path, request, false));
        let responses = join_all(tasks).await;
        responses
            .into_iter()
            .map(|r| r.map_err(actix_web::error::ErrorBadGateway))
            .collect()
    }

    /// Route a generate-style request through a (prefill, decode) pair:
    /// inject bootstrap info for the chosen prefill server, send the same
    /// payload to both, and return the decode server's response.
    pub async fn generate(
        &self,
        api_path: &str,
        mut req: Box<dyn Bootstrap>,
    ) -> Result<HttpResponse, actix_web::Error> {
        let (prefill, decode) = self.strategy_lb.select_pair(&self.client).await;
        let stream = req.is_stream();
        req.add_bootstrap_info(&prefill)?;
        let json = serde_json::to_value(req)?;
        let prefill_task = self.route_one(&prefill, Method::POST, api_path, Some(&json), false);
        let decode_task = self.route_one(&decode, Method::POST, api_path, Some(&json), stream);
        // NOTE(review): the prefill response — including errors — is
        // discarded; only the decode response reaches the client.
        let (_, decode_response) = tokio::join!(prefill_task, decode_task);
        decode_response?.into()
    }

    /// Query `/get_load` on every server and split the results into
    /// (prefill loads, decode loads).
    pub async fn get_engine_loads(
        &self,
    ) -> Result<(Vec<EngineLoad>, Vec<EngineLoad>), actix_web::Error> {
        let servers = self.strategy_lb.get_all_servers();
        let responses = self
            .route_collect(&servers, Method::GET, "/get_load", None)
            .await?;
        let loads = responses
            .into_iter()
            .enumerate()
            .map(|(i, r)| Ok(EngineLoad::from_json(&servers[i], &r.to_json()?)))
            .collect::<Result<Vec<EngineLoad>, actix_web::Error>>()?;
        let mut prefill_loads = Vec::new();
        let mut decode_loads = Vec::new();
        for load in loads {
            match load.engine_info.engine_type {
                EngineType::Prefill => prefill_loads.push(load),
                EngineType::Decode => decode_loads.push(load),
            }
        }
        Ok((prefill_loads, decode_loads))
    }
}
|
||||
68
sgl-pdlb/src/lib.rs
Normal file
68
sgl-pdlb/src/lib.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
pub mod io_struct;
|
||||
pub mod lb_state;
|
||||
pub mod server;
|
||||
pub mod strategy_lb;
|
||||
use pyo3::{exceptions::PyRuntimeError, prelude::*};
|
||||
|
||||
use lb_state::{LBConfig, LBState};
|
||||
use server::{periodic_logging, startup};
|
||||
use tokio::signal;
|
||||
|
||||
/// Python-visible handle around the Rust LB; holds only the config until
/// `start()` builds and runs the server.
#[pyclass]
pub struct LoadBalancer {
    lb_config: LBConfig,
}
|
||||
|
||||
#[pymethods]
impl LoadBalancer {
    /// Construct from keyword args passed by the Python launcher
    /// (`sglang.srt.disaggregation.launch_lb`).
    #[new]
    pub fn new(
        host: String,
        port: u16,
        policy: String,
        prefill_infos: Vec<(String, Option<u16>)>,
        decode_infos: Vec<String>,
        log_interval: u64,
        timeout: u64,
    ) -> PyResult<Self> {
        let lb_config = LBConfig {
            host,
            port,
            policy,
            prefill_infos,
            decode_infos,
            log_interval,
            timeout,
        };
        Ok(LoadBalancer { lb_config })
    }

    /// Run the LB server, blocking the calling Python thread until
    /// Ctrl+C or a startup error.
    pub fn start(&self) -> PyResult<()> {
        let lb_state = LBState::new(self.lb_config.clone()).map_err(|e| {
            PyRuntimeError::new_err(format!("Failed to build load balancer: {}", e))
        })?;

        let ret: PyResult<()> = actix_web::rt::System::new().block_on(async move {
            // The logging task and the server both run forever; only the
            // Ctrl+C branch (or a startup error) exits this select.
            tokio::select! {
                _ = periodic_logging(lb_state.clone()) => {
                    unreachable!()
                }
                res = startup(self.lb_config.clone(), lb_state) => {
                    res.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
                    unreachable!()
                }
                _ = signal::ctrl_c() => {
                    println!("Received Ctrl+C, shutting down");
                    // Hard exit: skips Python/actix teardown by design.
                    std::process::exit(0);
                }
            }
        });
        ret
    }
}
|
||||
|
||||
/// PyO3 module entry point: exposed to Python as `sgl_pdlb._rust`
/// (see `module-name` in pyproject.toml).
#[pymodule]
fn _rust(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
    m.add_class::<LoadBalancer>()?;
    Ok(())
}
|
||||
46
sgl-pdlb/src/main.rs
Normal file
46
sgl-pdlb/src/main.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
mod io_struct;
|
||||
mod lb_state;
|
||||
mod server;
|
||||
mod strategy_lb;
|
||||
|
||||
use lb_state::{LBConfig, LBState};
|
||||
use server::{periodic_logging, startup};
|
||||
use tokio::signal;
|
||||
|
||||
/// Standalone binary entry point with hard-coded placeholder servers.
fn main() -> anyhow::Result<()> {
    // FIXME: test code, move to test folder
    // 8 fake prefill servers without bootstrap ports.
    let prefill_infos = (0..8)
        .map(|i| (format!("123.123.123.123:{}", i), None))
        .collect::<Vec<(String, Option<u16>)>>();

    // 32 fake decode servers.
    let decode_infos = (0..32)
        .map(|i| format!("233.233.233.233:{}", i))
        .collect::<Vec<String>>();

    let lb_config = LBConfig {
        host: "localhost".to_string(),
        port: 8080,
        policy: "random".to_string(),
        prefill_infos,
        decode_infos,
        log_interval: 5,
        timeout: 600,
    };
    let lb_state = LBState::new(lb_config.clone()).map_err(|e| anyhow::anyhow!(e))?;
    // Same run loop as the Python-facing `LoadBalancer::start`:
    // only Ctrl+C (or a startup error) leaves the select.
    let ret: anyhow::Result<()> = actix_web::rt::System::new().block_on(async move {
        tokio::select! {
            _ = periodic_logging(lb_state.clone()) => {
                unreachable!()
            }
            res = startup(lb_config.clone(), lb_state) => {
                res.map_err(|e| anyhow::anyhow!(e))?;
                unreachable!()
            }
            _ = signal::ctrl_c() => {
                println!("Received Ctrl+C, shutting down");
                std::process::exit(0);
            }
        }
    });
    ret
}
|
||||
171
sgl-pdlb/src/server.rs
Normal file
171
sgl-pdlb/src/server.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use crate::io_struct::{ChatReqInput, GenerateReqInput};
|
||||
use crate::lb_state::{LBConfig, LBState};
|
||||
use crate::strategy_lb::EngineType;
|
||||
use actix_web::{HttpRequest, HttpResponse, HttpServer, get, post, web};
|
||||
use reqwest::Method;
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
|
||||
/// Liveness probe for the LB itself; does not contact any engine.
#[get("/health")]
pub async fn health(_req: HttpRequest, _: web::Data<LBState>) -> HttpResponse {
    HttpResponse::Ok().body("Ok")
}
|
||||
|
||||
/// Fan `/health_generate` out to every engine; succeeds only if all do.
#[get("/health_generate")]
pub async fn health_generate(
    _req: HttpRequest,
    app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
    let servers = app_state.strategy_lb.get_all_servers();
    app_state
        .route_collect(&servers, Method::GET, "/health_generate", None)
        .await?;
    // FIXME: log the response
    Ok(HttpResponse::Ok().body("Health check passed on all servers"))
}
|
||||
|
||||
/// Broadcast `/flush_cache` to every engine.
#[post("/flush_cache")]
pub async fn flush_cache(
    _req: HttpRequest,
    app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
    let servers = app_state.strategy_lb.get_all_servers();
    app_state
        .route_collect(&servers, Method::POST, "/flush_cache", None)
        .await?;
    Ok(HttpResponse::Ok().body("Cache flushed on all servers"))
}
|
||||
|
||||
/// Proxy `/get_model_info` to a single representative engine.
#[get("/get_model_info")]
pub async fn get_model_info(
    _req: HttpRequest,
    app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
    // Return the first server's model info
    let engine = app_state.strategy_lb.get_one_server();
    app_state
        .route_one(&engine, Method::GET, "/get_model_info", None, false)
        .await?
        .into()
}
|
||||
|
||||
/// `/generate`: route through a (prefill, decode) pair via `LBState::generate`.
#[post("/generate")]
pub async fn generate(
    _req: HttpRequest,
    req: web::Json<GenerateReqInput>,
    app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
    app_state
        .generate("/generate", Box::new(req.into_inner()))
        .await
}
|
||||
|
||||
/// OpenAI-compatible chat endpoint; same prefill/decode routing as `/generate`.
#[post("/v1/chat/completions")]
pub async fn chat_completions(
    _req: HttpRequest,
    req: web::Json<ChatReqInput>,
    app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
    app_state
        .generate("/v1/chat/completions", Box::new(req.into_inner()))
        .await
}
|
||||
|
||||
/// Collect `/get_server_info` from every engine and group the JSON
/// payloads into `{"prefill": [...], "decode": [...]}`.
#[get("/get_server_info")]
pub async fn get_server_info(
    _req: HttpRequest,
    app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
    let servers = app_state.strategy_lb.get_all_servers();
    let responses = app_state
        .route_collect(&servers, Method::GET, "/get_server_info", None)
        .await?;
    let mut prefill_infos = Vec::new();
    let mut decode_infos = Vec::new();
    // `route_collect` preserves order, so responses[i] belongs to servers[i].
    for (i, resp) in responses.iter().enumerate() {
        let json = resp.to_json()?;
        match servers[i].engine_type {
            EngineType::Prefill => prefill_infos.push(json),
            EngineType::Decode => decode_infos.push(json),
        }
    }
    Ok(HttpResponse::Ok().json(json!({
        "prefill": prefill_infos,
        "decode": decode_infos,
    })))
}
|
||||
|
||||
/// Report the current per-engine loads as JSON, split by engine type.
#[get("/get_loads")]
pub async fn get_loads(
    _req: HttpRequest,
    app_state: web::Data<LBState>,
) -> Result<HttpResponse, actix_web::Error> {
    let (prefill_loads, decode_loads) = app_state.get_engine_loads().await?;
    Ok(HttpResponse::Ok().json(json!({
        "prefill": prefill_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>(),
        "decode": decode_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>()
    })))
}
|
||||
|
||||
/// Background task: every `log_interval` seconds, fetch and log all engine
/// loads. Never returns; errors are logged and the loop continues.
pub async fn periodic_logging(lb_state: LBState) {
    // FIXME: currently we can just clone the lb_state to log as the lb is stateless
    loop {
        tokio::time::sleep(std::time::Duration::from_secs(lb_state.log_interval)).await;
        let (prefill_loads, decode_loads) = match lb_state.get_engine_loads().await {
            Ok((prefill_loads, decode_loads)) => (prefill_loads, decode_loads),
            Err(e) => {
                // Best-effort: skip this tick rather than abort the task.
                log::error!("Failed to get engine loads: {}", e);
                continue;
            }
        };
        let prefill_loads = prefill_loads
            .into_iter()
            .map(|l| l.to_string())
            .collect::<Vec<_>>();
        let decode_loads = decode_loads
            .into_iter()
            .map(|l| l.to_string())
            .collect::<Vec<_>>();
        log::info!("Prefill loads: {}", prefill_loads.join(", "));
        log::info!("Decode loads: {}", decode_loads.join(", "));
    }
}
|
||||
|
||||
/// Configure logging, register all routes, and run the HTTP server until
/// it stops (normally never).
pub async fn startup(lb_config: LBConfig, lb_state: LBState) -> std::io::Result<()> {
    let app_state = web::Data::new(lb_state);

    println!("Starting server at {}:{}", lb_config.host, lb_config.port);

    // default level is info
    // NOTE(review): env_logger::init panics if a global logger is already
    // set — confirm startup() is only ever called once per process.
    env_logger::Builder::new()
        .format(|buf, record| {
            writeln!(
                buf,
                "{} - {} - {}",
                chrono::Local::now().format("%Y-%m-%d %H:%M:%S"),
                record.level(),
                record.args()
            )
        })
        .filter(None, log::LevelFilter::Info)
        .init();

    HttpServer::new(move || {
        actix_web::App::new()
            .wrap(actix_web::middleware::Logger::default())
            .app_data(app_state.clone())
            .service(health)
            .service(health_generate)
            .service(flush_cache)
            .service(get_model_info)
            .service(get_server_info)
            .service(get_loads)
            .service(generate)
            .service(chat_completions)
    })
    .bind((lb_config.host, lb_config.port))?
    .run()
    .await?;

    std::io::Result::Ok(())
}
|
||||
182
sgl-pdlb/src/strategy_lb.rs
Normal file
182
sgl-pdlb/src/strategy_lb.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use rand::Rng;
|
||||
use serde_json::json;
|
||||
|
||||
/// Role of an engine in PD disaggregation.
#[derive(Debug, Clone)]
pub enum EngineType {
    Prefill,
    Decode,
}
|
||||
|
||||
/// Static description of one backend engine.
#[derive(Debug, Clone)]
pub struct EngineInfo {
    pub engine_type: EngineType,
    pub url: String,
    /// Only meaningful for prefill engines; always None for decode.
    pub bootstrap_port: Option<u16>,
}
|
||||
|
||||
impl EngineInfo {
    /// Descriptor for a prefill engine with an optional bootstrap port.
    pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
        EngineInfo {
            engine_type: EngineType::Prefill,
            url,
            bootstrap_port,
        }
    }

    /// Descriptor for a decode engine (no bootstrap port).
    pub fn new_decode(url: String) -> Self {
        EngineInfo {
            engine_type: EngineType::Decode,
            url,
            bootstrap_port: None,
        }
    }

    /// Join the engine base URL with `api_path`, tolerating a missing
    /// leading slash on `api_path`.
    pub fn api_path(&self, api_path: &str) -> String {
        if api_path.starts_with("/") {
            format!("{}{}", self.url, api_path)
        } else {
            format!("{}/{}", self.url, api_path)
        }
    }

    /// Short label, e.g. "(Prefill@http://host:1234)".
    // NOTE(review): an inherent `to_string` shadows `ToString::to_string`;
    // implementing `Display` would be more idiomatic
    // (clippy::inherent_to_string).
    pub fn to_string(&self) -> String {
        format!("({:?}@{})", self.engine_type, self.url)
    }

    /// Hostname part of the URL, with the http(s) scheme and port stripped.
    pub fn get_hostname(&self) -> String {
        let url = self
            .url
            .trim_start_matches("http://")
            .trim_start_matches("https://");
        url.split(':').next().unwrap().to_string()
    }
}
|
||||
|
||||
/// Snapshot of one engine's reported load (as parsed from /get_load).
pub struct EngineLoad {
    /// The engine this load sample belongs to.
    pub engine_info: EngineInfo,
    /// Reported load value; -1 is the sentinel for unknown/unparsable
    /// (see `EngineLoad::from_json`).
    pub load: isize,
}
|
||||
|
||||
impl EngineLoad {
|
||||
pub fn from_json(engine_info: &EngineInfo, json: &serde_json::Value) -> Self {
|
||||
let load = match json.get("load") {
|
||||
Some(load) => load.as_i64().unwrap_or(-1) as isize,
|
||||
None => -1,
|
||||
};
|
||||
EngineLoad {
|
||||
engine_info: engine_info.clone(),
|
||||
load,
|
||||
}
|
||||
}
|
||||
pub fn to_json(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"engine": self.engine_info.to_string(),
|
||||
"load": self.load,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> String {
|
||||
format!("{}: {}", self.engine_info.to_string(), self.load)
|
||||
}
|
||||
}
|
||||
|
||||
/// Strategy used to pick a (prefill, decode) server pair per request.
#[derive(Debug, Clone)]
pub enum LBPolicy {
    /// Choose each server uniformly at random.
    Random,
    /// Power-of-two-choices: sample two candidates per role and keep the
    /// one reporting the lower load.
    PowerOfTwo,
}
|
||||
|
||||
/// Load balancer over fixed lists of prefill and decode engines.
#[derive(Debug, Clone)]
pub struct StrategyLB {
    /// Pair-selection policy (random or power-of-two-choices).
    pub policy: LBPolicy,
    /// Candidate prefill engines.
    pub prefill_servers: Vec<EngineInfo>,
    /// Candidate decode engines.
    pub decode_servers: Vec<EngineInfo>,
}
|
||||
|
||||
impl StrategyLB {
|
||||
pub fn new(
|
||||
policy: LBPolicy,
|
||||
prefill_servers: Vec<EngineInfo>,
|
||||
decode_servers: Vec<EngineInfo>,
|
||||
) -> Self {
|
||||
StrategyLB {
|
||||
policy,
|
||||
prefill_servers,
|
||||
decode_servers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_one_server(&self) -> EngineInfo {
|
||||
assert!(!self.prefill_servers.is_empty());
|
||||
assert!(!self.decode_servers.is_empty());
|
||||
self.prefill_servers[0].clone()
|
||||
}
|
||||
|
||||
pub fn get_all_servers(&self) -> Vec<EngineInfo> {
|
||||
let mut all_servers = Vec::new();
|
||||
all_servers.extend(self.prefill_servers.clone());
|
||||
all_servers.extend(self.decode_servers.clone());
|
||||
all_servers
|
||||
}
|
||||
|
||||
pub async fn select_pair(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
|
||||
match self.policy {
|
||||
LBPolicy::Random => self.select_pd_pair_random(),
|
||||
LBPolicy::PowerOfTwo => self.select_pd_pair_po2(client).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn select_pd_pair_random(&self) -> (EngineInfo, EngineInfo) {
|
||||
let mut rng = rand::rng();
|
||||
let prefill_index = rng.random_range(0..self.prefill_servers.len());
|
||||
let decode_index = rng.random_range(0..self.decode_servers.len());
|
||||
|
||||
(
|
||||
self.prefill_servers[prefill_index].clone(),
|
||||
self.decode_servers[decode_index].clone(),
|
||||
)
|
||||
}
|
||||
|
||||
async fn get_load_from_engine(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
engine_info: &EngineInfo,
|
||||
) -> Option<isize> {
|
||||
let url = engine_info.api_path("/get_load");
|
||||
let response = client.get(url).send().await.unwrap();
|
||||
match response.status() {
|
||||
reqwest::StatusCode::OK => {
|
||||
let data = response.json::<serde_json::Value>().await.unwrap();
|
||||
Some(data["load"].as_i64().unwrap() as isize)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn select_pd_pair_po2(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
|
||||
let mut rng = rand::rng();
|
||||
let prefill1 =
|
||||
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
|
||||
let prefill2 =
|
||||
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
|
||||
let decode1 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
|
||||
let decode2 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
|
||||
let prefill1_load = self.get_load_from_engine(client, &prefill1).await;
|
||||
let prefill2_load = self.get_load_from_engine(client, &prefill2).await;
|
||||
let decode1_load = self.get_load_from_engine(client, &decode1).await;
|
||||
let decode2_load = self.get_load_from_engine(client, &decode2).await;
|
||||
|
||||
(
|
||||
if prefill1_load < prefill2_load {
|
||||
prefill1
|
||||
} else {
|
||||
prefill2
|
||||
},
|
||||
if decode1_load < decode2_load {
|
||||
decode1
|
||||
} else {
|
||||
decode2
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user