PD Rust LB (PO2) (#6437)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
140
python/sglang/srt/disaggregation/launch_lb.py
Normal file
140
python/sglang/srt/disaggregation/launch_lb.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class LBArgs:
|
||||||
|
rust_lb: bool = False
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8000
|
||||||
|
policy: str = "random"
|
||||||
|
prefill_infos: list = dataclasses.field(default_factory=list)
|
||||||
|
decode_infos: list = dataclasses.field(default_factory=list)
|
||||||
|
log_interval: int = 5
|
||||||
|
timeout: int = 600
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument(
|
||||||
|
"--rust-lb",
|
||||||
|
action="store_true",
|
||||||
|
help="Use Rust load balancer",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--host",
|
||||||
|
type=str,
|
||||||
|
default=LBArgs.host,
|
||||||
|
help=f"Host to bind the server (default: {LBArgs.host})",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=LBArgs.port,
|
||||||
|
help=f"Port to bind the server (default: {LBArgs.port})",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--policy",
|
||||||
|
type=str,
|
||||||
|
default=LBArgs.policy,
|
||||||
|
choices=["random", "po2"],
|
||||||
|
help=f"Policy to use for load balancing (default: {LBArgs.policy})",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill",
|
||||||
|
type=str,
|
||||||
|
default=[],
|
||||||
|
nargs="+",
|
||||||
|
help="URLs for prefill servers",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode",
|
||||||
|
type=str,
|
||||||
|
default=[],
|
||||||
|
nargs="+",
|
||||||
|
help="URLs for decode servers",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill-bootstrap-ports",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
help="Bootstrap ports for prefill servers",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-interval",
|
||||||
|
type=int,
|
||||||
|
default=LBArgs.log_interval,
|
||||||
|
help=f"Log interval in seconds (default: {LBArgs.log_interval})",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=int,
|
||||||
|
default=LBArgs.timeout,
|
||||||
|
help=f"Timeout in seconds (default: {LBArgs.timeout})",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_cli_args(cls, args: argparse.Namespace) -> "LBArgs":
|
||||||
|
bootstrap_ports = args.prefill_bootstrap_ports
|
||||||
|
if bootstrap_ports is None:
|
||||||
|
bootstrap_ports = [None] * len(args.prefill)
|
||||||
|
elif len(bootstrap_ports) == 1:
|
||||||
|
bootstrap_ports = bootstrap_ports * len(args.prefill)
|
||||||
|
else:
|
||||||
|
if len(bootstrap_ports) != len(args.prefill):
|
||||||
|
raise ValueError(
|
||||||
|
"Number of prefill URLs must match number of bootstrap ports"
|
||||||
|
)
|
||||||
|
|
||||||
|
prefill_infos = [
|
||||||
|
(url, port) for url, port in zip(args.prefill, bootstrap_ports)
|
||||||
|
]
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
rust_lb=args.rust_lb,
|
||||||
|
host=args.host,
|
||||||
|
port=args.port,
|
||||||
|
policy=args.policy,
|
||||||
|
prefill_infos=prefill_infos,
|
||||||
|
decode_infos=args.decode,
|
||||||
|
log_interval=args.log_interval,
|
||||||
|
timeout=args.timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not self.rust_lb:
|
||||||
|
assert (
|
||||||
|
self.policy == "random"
|
||||||
|
), "Only random policy is supported for Python load balancer"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="PD Disaggregation Load Balancer Server"
|
||||||
|
)
|
||||||
|
LBArgs.add_cli_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
lb_args = LBArgs.from_cli_args(args)
|
||||||
|
|
||||||
|
if lb_args.rust_lb:
|
||||||
|
from sgl_pdlb._rust import LoadBalancer as RustLB
|
||||||
|
|
||||||
|
RustLB(
|
||||||
|
host=lb_args.host,
|
||||||
|
port=lb_args.port,
|
||||||
|
policy=lb_args.policy,
|
||||||
|
prefill_infos=lb_args.prefill_infos,
|
||||||
|
decode_infos=lb_args.decode_infos,
|
||||||
|
log_interval=lb_args.log_interval,
|
||||||
|
timeout=lb_args.timeout,
|
||||||
|
).start()
|
||||||
|
else:
|
||||||
|
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
|
||||||
|
|
||||||
|
prefill_configs = [
|
||||||
|
PrefillConfig(url, port) for url, port in lb_args.prefill_infos
|
||||||
|
]
|
||||||
|
run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -377,42 +377,7 @@ def run(prefill_configs, decode_addrs, host, port):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import argparse
|
# FIXME: remove this, use the unified entry point: sglang.srt.disaggregation.launch_lb
|
||||||
|
from sglang.srt.disaggregation.launch_lb import main
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
|
main()
|
||||||
parser.add_argument(
|
|
||||||
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--prefill-bootstrap-ports",
|
|
||||||
type=int,
|
|
||||||
nargs="+",
|
|
||||||
help="Bootstrap ports for prefill servers",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
bootstrap_ports = args.prefill_bootstrap_ports
|
|
||||||
if bootstrap_ports is None:
|
|
||||||
bootstrap_ports = [None] * len(args.prefill)
|
|
||||||
elif len(bootstrap_ports) == 1:
|
|
||||||
bootstrap_ports = bootstrap_ports * len(args.prefill)
|
|
||||||
else:
|
|
||||||
if len(bootstrap_ports) != len(args.prefill):
|
|
||||||
raise ValueError(
|
|
||||||
"Number of prefill URLs must match number of bootstrap ports"
|
|
||||||
)
|
|
||||||
|
|
||||||
prefill_configs = [
|
|
||||||
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
|
|
||||||
]
|
|
||||||
|
|
||||||
run(prefill_configs, args.decode, args.host, args.port)
|
|
||||||
|
|||||||
@@ -229,6 +229,11 @@ async def get_server_info():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/get_load")
|
||||||
|
async def get_load():
|
||||||
|
return await _global_state.tokenizer_manager.get_load()
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
|
@app.api_route("/set_internal_state", methods=["POST", "PUT"])
|
||||||
async def set_internal_state(obj: SetInternalStateReq, request: Request):
|
async def set_internal_state(obj: SetInternalStateReq, request: Request):
|
||||||
res = await _global_state.tokenizer_manager.set_internal_state(obj)
|
res = await _global_state.tokenizer_manager.set_internal_state(obj)
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ class GenerateReqInput:
|
|||||||
|
|
||||||
# For disaggregated inference
|
# For disaggregated inference
|
||||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||||
bootstrap_port: Optional[Union[List[int], int]] = None
|
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
||||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||||
|
|
||||||
def contains_mm_input(self) -> bool:
|
def contains_mm_input(self) -> bool:
|
||||||
|
|||||||
@@ -1911,6 +1911,27 @@ class Scheduler(
|
|||||||
if_success = False
|
if_success = False
|
||||||
return if_success
|
return if_success
|
||||||
|
|
||||||
|
def get_load(self):
|
||||||
|
# TODO(lsyin): use dynamically maintained num_waiting_tokens
|
||||||
|
load = (
|
||||||
|
self.max_total_num_tokens
|
||||||
|
- self.token_to_kv_pool_allocator.available_size()
|
||||||
|
- self.tree_cache.evictable_size()
|
||||||
|
)
|
||||||
|
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
|
||||||
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
load += sum(
|
||||||
|
len(req.origin_input_ids)
|
||||||
|
for req in self.disagg_prefill_bootstrap_queue.queue
|
||||||
|
)
|
||||||
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
|
load += sum(
|
||||||
|
len(req.req.origin_input_ids)
|
||||||
|
for req in self.disagg_decode_prealloc_queue.queue
|
||||||
|
)
|
||||||
|
|
||||||
|
return load
|
||||||
|
|
||||||
def get_internal_state(self, recv_req: GetInternalStateReq):
|
def get_internal_state(self, recv_req: GetInternalStateReq):
|
||||||
ret = dict(global_server_args_dict)
|
ret = dict(global_server_args_dict)
|
||||||
ret["last_gen_throughput"] = self.last_gen_throughput
|
ret["last_gen_throughput"] = self.last_gen_throughput
|
||||||
@@ -1920,9 +1941,10 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
if RECORD_STEP_TIME:
|
if RECORD_STEP_TIME:
|
||||||
ret["step_time_dict"] = self.step_time_dict
|
ret["step_time_dict"] = self.step_time_dict
|
||||||
return GetInternalStateReqOutput(
|
|
||||||
internal_state=ret,
|
ret["load"] = self.get_load()
|
||||||
)
|
|
||||||
|
return GetInternalStateReqOutput(internal_state=ret)
|
||||||
|
|
||||||
def set_internal_state(self, recv_req: SetInternalStateReq):
|
def set_internal_state(self, recv_req: SetInternalStateReq):
|
||||||
server_args_dict = recv_req.server_args
|
server_args_dict = recv_req.server_args
|
||||||
|
|||||||
@@ -395,6 +395,9 @@ class TokenizerManager:
|
|||||||
self.server_args.disaggregation_bootstrap_port
|
self.server_args.disaggregation_bootstrap_port
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.current_load = 0
|
||||||
|
self.current_load_lock = asyncio.Lock()
|
||||||
|
|
||||||
async def generate_request(
|
async def generate_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
@@ -983,6 +986,14 @@ class TokenizerManager:
|
|||||||
# Many DP ranks
|
# Many DP ranks
|
||||||
return [res.internal_state for res in responses]
|
return [res.internal_state for res in responses]
|
||||||
|
|
||||||
|
async def get_load(self) -> dict:
|
||||||
|
# TODO(lsyin): fake load report server
|
||||||
|
if not self.current_load_lock.locked():
|
||||||
|
async with self.current_load_lock:
|
||||||
|
internal_state = await self.get_internal_state()
|
||||||
|
self.current_load = internal_state[0]["load"]
|
||||||
|
return {"load": self.current_load}
|
||||||
|
|
||||||
async def set_internal_state(
|
async def set_internal_state(
|
||||||
self, obj: SetInternalStateReq
|
self, obj: SetInternalStateReq
|
||||||
) -> SetInternalStateReqOutput:
|
) -> SetInternalStateReqOutput:
|
||||||
|
|||||||
2
sgl-pdlb/.rustfmt.toml
Normal file
2
sgl-pdlb/.rustfmt.toml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
reorder_imports = true
|
||||||
|
reorder_modules = true
|
||||||
28
sgl-pdlb/Cargo.toml
Normal file
28
sgl-pdlb/Cargo.toml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
[package]
|
||||||
|
edition = "2024"
|
||||||
|
name = "sgl-pdlb"
|
||||||
|
version = "0.1.0"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
crate-type = ["cdylib", "rlib"]
|
||||||
|
name = "sgl_pdlb_rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
actix-web = "4.11"
|
||||||
|
bytes = "1.8.0"
|
||||||
|
chrono = "0.4.38"
|
||||||
|
clap = { version = "4.4", features = ["derive"] }
|
||||||
|
dashmap = "6.1.0"
|
||||||
|
env_logger = "0.11.5"
|
||||||
|
futures = "0.3"
|
||||||
|
futures-util = "0.3"
|
||||||
|
http = "1.3.1"
|
||||||
|
log = "0.4.22"
|
||||||
|
pyo3 = { version = "0.25.0", features = ["extension-module"] }
|
||||||
|
rand = "0.9.0"
|
||||||
|
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
tokio = { version = "1.34", features = ["full"] }
|
||||||
|
anyhow = "1.0.98"
|
||||||
|
typetag = "0.2.20"
|
||||||
12
sgl-pdlb/README.md
Normal file
12
sgl-pdlb/README.md
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
### Install dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install "maturin[patchelf]"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Build and install
|
||||||
|
|
||||||
|
```bash
|
||||||
|
maturin develop
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
1
sgl-pdlb/py_src/sgl_pdlb/__init__.py
Normal file
1
sgl-pdlb/py_src/sgl_pdlb/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
__version__ = "0.0.1"
|
||||||
14
sgl-pdlb/pyproject.toml
Normal file
14
sgl-pdlb/pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["maturin>=1.8.0"]
|
||||||
|
build-backend = "maturin"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "sgl_pdlb"
|
||||||
|
version = "0.0.1"
|
||||||
|
|
||||||
|
[tool.maturin]
|
||||||
|
python-source = "py_src"
|
||||||
|
module-name = "sgl_pdlb._rust"
|
||||||
|
|
||||||
|
[tool.maturin.build-backend]
|
||||||
|
features = ["pyo3/extension-module"]
|
||||||
133
sgl-pdlb/src/io_struct.rs
Normal file
133
sgl-pdlb/src/io_struct.rs
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
use crate::strategy_lb::EngineInfo;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum SingleOrBatch<T> {
|
||||||
|
Single(T),
|
||||||
|
Batch(Vec<T>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type InputIds = SingleOrBatch<Vec<i32>>;
|
||||||
|
pub type InputText = SingleOrBatch<String>;
|
||||||
|
pub type BootstrapHost = SingleOrBatch<String>;
|
||||||
|
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
|
||||||
|
pub type BootstrapRoom = SingleOrBatch<u64>;
|
||||||
|
|
||||||
|
#[typetag::serde(tag = "type")]
|
||||||
|
pub trait Bootstrap {
|
||||||
|
fn is_stream(&self) -> bool;
|
||||||
|
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error>;
|
||||||
|
fn set_bootstrap_info(
|
||||||
|
&mut self,
|
||||||
|
bootstrap_host: BootstrapHost,
|
||||||
|
bootstrap_port: BootstrapPort,
|
||||||
|
bootstrap_room: BootstrapRoom,
|
||||||
|
);
|
||||||
|
|
||||||
|
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), actix_web::Error> {
|
||||||
|
let batch_size = self.get_batch_size()?;
|
||||||
|
if let Some(batch_size) = batch_size {
|
||||||
|
self.set_bootstrap_info(
|
||||||
|
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
|
||||||
|
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
|
||||||
|
BootstrapRoom::Batch((0..batch_size).map(|_| rand::random::<u64>()).collect()),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
self.set_bootstrap_info(
|
||||||
|
BootstrapHost::Single(prefill_info.get_hostname()),
|
||||||
|
BootstrapPort::Single(prefill_info.bootstrap_port),
|
||||||
|
BootstrapRoom::Single(rand::random::<u64>()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
pub struct GenerateReqInput {
|
||||||
|
pub text: Option<InputText>,
|
||||||
|
pub input_ids: Option<InputIds>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
pub bootstrap_host: Option<BootstrapHost>,
|
||||||
|
pub bootstrap_port: Option<BootstrapPort>,
|
||||||
|
pub bootstrap_room: Option<BootstrapRoom>,
|
||||||
|
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub other: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerateReqInput {
|
||||||
|
pub fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
|
||||||
|
if self.text.is_some() && self.input_ids.is_some() {
|
||||||
|
return Err(actix_web::error::ErrorBadRequest(
|
||||||
|
"Both text and input_ids are present in the request".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if let Some(InputText::Batch(texts)) = &self.text {
|
||||||
|
return Ok(Some(texts.len()));
|
||||||
|
}
|
||||||
|
if let Some(InputIds::Batch(ids)) = &self.input_ids {
|
||||||
|
return Ok(Some(ids.len()));
|
||||||
|
}
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[typetag::serde]
|
||||||
|
impl Bootstrap for GenerateReqInput {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
|
||||||
|
self.get_batch_size()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_bootstrap_info(
|
||||||
|
&mut self,
|
||||||
|
bootstrap_host: BootstrapHost,
|
||||||
|
bootstrap_port: BootstrapPort,
|
||||||
|
bootstrap_room: BootstrapRoom,
|
||||||
|
) {
|
||||||
|
self.bootstrap_host = Some(bootstrap_host);
|
||||||
|
self.bootstrap_port = Some(bootstrap_port);
|
||||||
|
self.bootstrap_room = Some(bootstrap_room);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ChatReqInput {
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
pub bootstrap_host: Option<BootstrapHost>,
|
||||||
|
pub bootstrap_port: Option<BootstrapPort>,
|
||||||
|
pub bootstrap_room: Option<BootstrapRoom>,
|
||||||
|
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub other: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[typetag::serde]
|
||||||
|
impl Bootstrap for ChatReqInput {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_bootstrap_info(
|
||||||
|
&mut self,
|
||||||
|
bootstrap_host: BootstrapHost,
|
||||||
|
bootstrap_port: BootstrapPort,
|
||||||
|
bootstrap_room: BootstrapRoom,
|
||||||
|
) {
|
||||||
|
self.bootstrap_host = Some(bootstrap_host);
|
||||||
|
self.bootstrap_port = Some(bootstrap_port);
|
||||||
|
self.bootstrap_room = Some(bootstrap_room);
|
||||||
|
}
|
||||||
|
}
|
||||||
175
sgl-pdlb/src/lb_state.rs
Normal file
175
sgl-pdlb/src/lb_state.rs
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
use crate::io_struct::Bootstrap;
|
||||||
|
use crate::strategy_lb::{EngineInfo, EngineLoad, EngineType, LBPolicy, StrategyLB};
|
||||||
|
use actix_web::HttpResponse;
|
||||||
|
use bytes::Bytes;
|
||||||
|
use futures::{Stream, StreamExt, future::join_all};
|
||||||
|
use reqwest::{Method, StatusCode};
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
pub enum ProxyResponseBody {
|
||||||
|
Full(Bytes),
|
||||||
|
Stream(Pin<Box<dyn Stream<Item = Result<Bytes, actix_web::Error>> + Send>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ProxyResponse {
|
||||||
|
pub status: StatusCode,
|
||||||
|
pub body: ProxyResponseBody,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProxyResponse {
|
||||||
|
pub fn to_json(&self) -> Result<serde_json::Value, actix_web::Error> {
|
||||||
|
match &self.body {
|
||||||
|
ProxyResponseBody::Full(body) => Ok(serde_json::from_slice(&body)?),
|
||||||
|
ProxyResponseBody::Stream(_) => Err(actix_web::error::ErrorBadRequest(
|
||||||
|
"Stream response is not supported",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Into<Result<HttpResponse, actix_web::Error>> for ProxyResponse {
|
||||||
|
fn into(self) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
let status = actix_web::http::StatusCode::from_u16(self.status.as_u16()).map_err(|e| {
|
||||||
|
actix_web::error::ErrorBadGateway(format!("Invalid status code: {}", e))
|
||||||
|
})?;
|
||||||
|
match self.body {
|
||||||
|
ProxyResponseBody::Full(body) => Ok(HttpResponse::Ok().status(status).body(body)),
|
||||||
|
ProxyResponseBody::Stream(body) => Ok(HttpResponse::Ok()
|
||||||
|
.status(status)
|
||||||
|
.content_type("application/octet-stream")
|
||||||
|
.streaming(body)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct LBConfig {
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub policy: String,
|
||||||
|
pub prefill_infos: Vec<(String, Option<u16>)>,
|
||||||
|
pub decode_infos: Vec<String>,
|
||||||
|
pub log_interval: u64,
|
||||||
|
pub timeout: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct LBState {
|
||||||
|
pub strategy_lb: StrategyLB,
|
||||||
|
pub client: reqwest::Client,
|
||||||
|
pub log_interval: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LBState {
|
||||||
|
pub fn new(lb_config: LBConfig) -> anyhow::Result<Self> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(std::time::Duration::from_secs(lb_config.timeout))
|
||||||
|
.build()?;
|
||||||
|
let policy = match lb_config.policy.as_str() {
|
||||||
|
"random" => LBPolicy::Random,
|
||||||
|
"po2" => LBPolicy::PowerOfTwo,
|
||||||
|
_ => anyhow::bail!("Invalid policy"),
|
||||||
|
};
|
||||||
|
let prefill_servers = lb_config
|
||||||
|
.prefill_infos
|
||||||
|
.into_iter()
|
||||||
|
.map(|(url, port)| EngineInfo::new_prefill(url, port))
|
||||||
|
.collect();
|
||||||
|
let decode_servers = lb_config
|
||||||
|
.decode_infos
|
||||||
|
.into_iter()
|
||||||
|
.map(|url| EngineInfo::new_decode(url))
|
||||||
|
.collect();
|
||||||
|
let lb = StrategyLB::new(policy, prefill_servers, decode_servers);
|
||||||
|
Ok(Self {
|
||||||
|
strategy_lb: lb,
|
||||||
|
client,
|
||||||
|
log_interval: lb_config.log_interval,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn route_one(
|
||||||
|
&self,
|
||||||
|
engine_info: &EngineInfo,
|
||||||
|
method: Method,
|
||||||
|
api_path: &str,
|
||||||
|
request: Option<&serde_json::Value>,
|
||||||
|
stream: bool,
|
||||||
|
) -> Result<ProxyResponse, actix_web::Error> {
|
||||||
|
let url = engine_info.api_path(api_path);
|
||||||
|
let request = request.unwrap_or(&serde_json::Value::Null);
|
||||||
|
let task = self.client.request(method, url).json(request).send();
|
||||||
|
let resp = task.await.map_err(actix_web::error::ErrorBadGateway)?;
|
||||||
|
// FIXME: handle error status code (map status code to error)
|
||||||
|
let status = resp.status();
|
||||||
|
let body = if stream {
|
||||||
|
let resp_stream = resp.bytes_stream().map(|r| {
|
||||||
|
r.map_err(actix_web::error::ErrorBadGateway)
|
||||||
|
.map(Bytes::from)
|
||||||
|
});
|
||||||
|
ProxyResponseBody::Stream(Box::pin(resp_stream))
|
||||||
|
} else {
|
||||||
|
let body = resp
|
||||||
|
.bytes()
|
||||||
|
.await
|
||||||
|
.map_err(actix_web::error::ErrorBadGateway)?;
|
||||||
|
ProxyResponseBody::Full(body)
|
||||||
|
};
|
||||||
|
Ok(ProxyResponse { status, body })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn route_collect(
|
||||||
|
&self,
|
||||||
|
engines: &Vec<EngineInfo>,
|
||||||
|
method: Method,
|
||||||
|
api_path: &str,
|
||||||
|
request: Option<&serde_json::Value>,
|
||||||
|
) -> Result<Vec<ProxyResponse>, actix_web::Error> {
|
||||||
|
let tasks = engines
|
||||||
|
.iter()
|
||||||
|
.map(|engine| self.route_one(engine, method.clone(), api_path, request, false));
|
||||||
|
let responses = join_all(tasks).await;
|
||||||
|
responses
|
||||||
|
.into_iter()
|
||||||
|
.map(|r| r.map_err(actix_web::error::ErrorBadGateway))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn generate(
|
||||||
|
&self,
|
||||||
|
api_path: &str,
|
||||||
|
mut req: Box<dyn Bootstrap>,
|
||||||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
let (prefill, decode) = self.strategy_lb.select_pair(&self.client).await;
|
||||||
|
let stream = req.is_stream();
|
||||||
|
req.add_bootstrap_info(&prefill)?;
|
||||||
|
let json = serde_json::to_value(req)?;
|
||||||
|
let prefill_task = self.route_one(&prefill, Method::POST, api_path, Some(&json), false);
|
||||||
|
let decode_task = self.route_one(&decode, Method::POST, api_path, Some(&json), stream);
|
||||||
|
let (_, decode_response) = tokio::join!(prefill_task, decode_task);
|
||||||
|
decode_response?.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_engine_loads(
|
||||||
|
&self,
|
||||||
|
) -> Result<(Vec<EngineLoad>, Vec<EngineLoad>), actix_web::Error> {
|
||||||
|
let servers = self.strategy_lb.get_all_servers();
|
||||||
|
let responses = self
|
||||||
|
.route_collect(&servers, Method::GET, "/get_load", None)
|
||||||
|
.await?;
|
||||||
|
let loads = responses
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, r)| Ok(EngineLoad::from_json(&servers[i], &r.to_json()?)))
|
||||||
|
.collect::<Result<Vec<EngineLoad>, actix_web::Error>>()?;
|
||||||
|
let mut prefill_loads = Vec::new();
|
||||||
|
let mut decode_loads = Vec::new();
|
||||||
|
for load in loads {
|
||||||
|
match load.engine_info.engine_type {
|
||||||
|
EngineType::Prefill => prefill_loads.push(load),
|
||||||
|
EngineType::Decode => decode_loads.push(load),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((prefill_loads, decode_loads))
|
||||||
|
}
|
||||||
|
}
|
||||||
68
sgl-pdlb/src/lib.rs
Normal file
68
sgl-pdlb/src/lib.rs
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
pub mod io_struct;
|
||||||
|
pub mod lb_state;
|
||||||
|
pub mod server;
|
||||||
|
pub mod strategy_lb;
|
||||||
|
use pyo3::{exceptions::PyRuntimeError, prelude::*};
|
||||||
|
|
||||||
|
use lb_state::{LBConfig, LBState};
|
||||||
|
use server::{periodic_logging, startup};
|
||||||
|
use tokio::signal;
|
||||||
|
|
||||||
|
#[pyclass]
|
||||||
|
pub struct LoadBalancer {
|
||||||
|
lb_config: LBConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl LoadBalancer {
|
||||||
|
#[new]
|
||||||
|
pub fn new(
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
policy: String,
|
||||||
|
prefill_infos: Vec<(String, Option<u16>)>,
|
||||||
|
decode_infos: Vec<String>,
|
||||||
|
log_interval: u64,
|
||||||
|
timeout: u64,
|
||||||
|
) -> PyResult<Self> {
|
||||||
|
let lb_config = LBConfig {
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
policy,
|
||||||
|
prefill_infos,
|
||||||
|
decode_infos,
|
||||||
|
log_interval,
|
||||||
|
timeout,
|
||||||
|
};
|
||||||
|
Ok(LoadBalancer { lb_config })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn start(&self) -> PyResult<()> {
|
||||||
|
let lb_state = LBState::new(self.lb_config.clone()).map_err(|e| {
|
||||||
|
PyRuntimeError::new_err(format!("Failed to build load balancer: {}", e))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let ret: PyResult<()> = actix_web::rt::System::new().block_on(async move {
|
||||||
|
tokio::select! {
|
||||||
|
_ = periodic_logging(lb_state.clone()) => {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
res = startup(self.lb_config.clone(), lb_state) => {
|
||||||
|
res.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
_ = signal::ctrl_c() => {
|
||||||
|
println!("Received Ctrl+C, shutting down");
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymodule]
|
||||||
|
fn _rust(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
|
||||||
|
m.add_class::<LoadBalancer>()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
46
sgl-pdlb/src/main.rs
Normal file
46
sgl-pdlb/src/main.rs
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
mod io_struct;
|
||||||
|
mod lb_state;
|
||||||
|
mod server;
|
||||||
|
mod strategy_lb;
|
||||||
|
|
||||||
|
use lb_state::{LBConfig, LBState};
|
||||||
|
use server::{periodic_logging, startup};
|
||||||
|
use tokio::signal;
|
||||||
|
|
||||||
|
fn main() -> anyhow::Result<()> {
|
||||||
|
// FIXME: test code, move to test folder
|
||||||
|
let prefill_infos = (0..8)
|
||||||
|
.map(|i| (format!("123.123.123.123:{}", i), None))
|
||||||
|
.collect::<Vec<(String, Option<u16>)>>();
|
||||||
|
|
||||||
|
let decode_infos = (0..32)
|
||||||
|
.map(|i| format!("233.233.233.233:{}", i))
|
||||||
|
.collect::<Vec<String>>();
|
||||||
|
|
||||||
|
let lb_config = LBConfig {
|
||||||
|
host: "localhost".to_string(),
|
||||||
|
port: 8080,
|
||||||
|
policy: "random".to_string(),
|
||||||
|
prefill_infos,
|
||||||
|
decode_infos,
|
||||||
|
log_interval: 5,
|
||||||
|
timeout: 600,
|
||||||
|
};
|
||||||
|
let lb_state = LBState::new(lb_config.clone()).map_err(|e| anyhow::anyhow!(e))?;
|
||||||
|
let ret: anyhow::Result<()> = actix_web::rt::System::new().block_on(async move {
|
||||||
|
tokio::select! {
|
||||||
|
_ = periodic_logging(lb_state.clone()) => {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
res = startup(lb_config.clone(), lb_state) => {
|
||||||
|
res.map_err(|e| anyhow::anyhow!(e))?;
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
_ = signal::ctrl_c() => {
|
||||||
|
println!("Received Ctrl+C, shutting down");
|
||||||
|
std::process::exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
ret
|
||||||
|
}
|
||||||
171
sgl-pdlb/src/server.rs
Normal file
171
sgl-pdlb/src/server.rs
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
use crate::io_struct::{ChatReqInput, GenerateReqInput};
|
||||||
|
use crate::lb_state::{LBConfig, LBState};
|
||||||
|
use crate::strategy_lb::EngineType;
|
||||||
|
use actix_web::{HttpRequest, HttpResponse, HttpServer, get, post, web};
|
||||||
|
use reqwest::Method;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::io::Write;
|
||||||
|
|
||||||
|
#[get("/health")]
|
||||||
|
pub async fn health(_req: HttpRequest, _: web::Data<LBState>) -> HttpResponse {
|
||||||
|
HttpResponse::Ok().body("Ok")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/health_generate")]
|
||||||
|
pub async fn health_generate(
|
||||||
|
_req: HttpRequest,
|
||||||
|
app_state: web::Data<LBState>,
|
||||||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
let servers = app_state.strategy_lb.get_all_servers();
|
||||||
|
app_state
|
||||||
|
.route_collect(&servers, Method::GET, "/health_generate", None)
|
||||||
|
.await?;
|
||||||
|
// FIXME: log the response
|
||||||
|
Ok(HttpResponse::Ok().body("Health check passed on all servers"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/flush_cache")]
|
||||||
|
pub async fn flush_cache(
|
||||||
|
_req: HttpRequest,
|
||||||
|
app_state: web::Data<LBState>,
|
||||||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
let servers = app_state.strategy_lb.get_all_servers();
|
||||||
|
app_state
|
||||||
|
.route_collect(&servers, Method::POST, "/flush_cache", None)
|
||||||
|
.await?;
|
||||||
|
Ok(HttpResponse::Ok().body("Cache flushed on all servers"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/get_model_info")]
|
||||||
|
pub async fn get_model_info(
|
||||||
|
_req: HttpRequest,
|
||||||
|
app_state: web::Data<LBState>,
|
||||||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
// Return the first server's model info
|
||||||
|
let engine = app_state.strategy_lb.get_one_server();
|
||||||
|
app_state
|
||||||
|
.route_one(&engine, Method::GET, "/get_model_info", None, false)
|
||||||
|
.await?
|
||||||
|
.into()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/generate")]
|
||||||
|
pub async fn generate(
|
||||||
|
_req: HttpRequest,
|
||||||
|
req: web::Json<GenerateReqInput>,
|
||||||
|
app_state: web::Data<LBState>,
|
||||||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
app_state
|
||||||
|
.generate("/generate", Box::new(req.into_inner()))
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[post("/v1/chat/completions")]
|
||||||
|
pub async fn chat_completions(
|
||||||
|
_req: HttpRequest,
|
||||||
|
req: web::Json<ChatReqInput>,
|
||||||
|
app_state: web::Data<LBState>,
|
||||||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
app_state
|
||||||
|
.generate("/v1/chat/completions", Box::new(req.into_inner()))
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/get_server_info")]
|
||||||
|
pub async fn get_server_info(
|
||||||
|
_req: HttpRequest,
|
||||||
|
app_state: web::Data<LBState>,
|
||||||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
let servers = app_state.strategy_lb.get_all_servers();
|
||||||
|
let responses = app_state
|
||||||
|
.route_collect(&servers, Method::GET, "/get_server_info", None)
|
||||||
|
.await?;
|
||||||
|
let mut prefill_infos = Vec::new();
|
||||||
|
let mut decode_infos = Vec::new();
|
||||||
|
for (i, resp) in responses.iter().enumerate() {
|
||||||
|
let json = resp.to_json()?;
|
||||||
|
match servers[i].engine_type {
|
||||||
|
EngineType::Prefill => prefill_infos.push(json),
|
||||||
|
EngineType::Decode => decode_infos.push(json),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(HttpResponse::Ok().json(json!({
|
||||||
|
"prefill": prefill_infos,
|
||||||
|
"decode": decode_infos,
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/get_loads")]
|
||||||
|
pub async fn get_loads(
|
||||||
|
_req: HttpRequest,
|
||||||
|
app_state: web::Data<LBState>,
|
||||||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||||||
|
let (prefill_loads, decode_loads) = app_state.get_engine_loads().await?;
|
||||||
|
Ok(HttpResponse::Ok().json(json!({
|
||||||
|
"prefill": prefill_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>(),
|
||||||
|
"decode": decode_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>()
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn periodic_logging(lb_state: LBState) {
|
||||||
|
// FIXME: currently we can just clone the lb_state to log as the lb is stateless
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(lb_state.log_interval)).await;
|
||||||
|
let (prefill_loads, decode_loads) = match lb_state.get_engine_loads().await {
|
||||||
|
Ok((prefill_loads, decode_loads)) => (prefill_loads, decode_loads),
|
||||||
|
Err(e) => {
|
||||||
|
log::error!("Failed to get engine loads: {}", e);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let prefill_loads = prefill_loads
|
||||||
|
.into_iter()
|
||||||
|
.map(|l| l.to_string())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let decode_loads = decode_loads
|
||||||
|
.into_iter()
|
||||||
|
.map(|l| l.to_string())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
log::info!("Prefill loads: {}", prefill_loads.join(", "));
|
||||||
|
log::info!("Decode loads: {}", decode_loads.join(", "));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn startup(lb_config: LBConfig, lb_state: LBState) -> std::io::Result<()> {
|
||||||
|
let app_state = web::Data::new(lb_state);
|
||||||
|
|
||||||
|
println!("Starting server at {}:{}", lb_config.host, lb_config.port);
|
||||||
|
|
||||||
|
// default level is info
|
||||||
|
env_logger::Builder::new()
|
||||||
|
.format(|buf, record| {
|
||||||
|
writeln!(
|
||||||
|
buf,
|
||||||
|
"{} - {} - {}",
|
||||||
|
chrono::Local::now().format("%Y-%m-%d %H:%M:%S"),
|
||||||
|
record.level(),
|
||||||
|
record.args()
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.filter(None, log::LevelFilter::Info)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
HttpServer::new(move || {
|
||||||
|
actix_web::App::new()
|
||||||
|
.wrap(actix_web::middleware::Logger::default())
|
||||||
|
.app_data(app_state.clone())
|
||||||
|
.service(health)
|
||||||
|
.service(health_generate)
|
||||||
|
.service(flush_cache)
|
||||||
|
.service(get_model_info)
|
||||||
|
.service(get_server_info)
|
||||||
|
.service(get_loads)
|
||||||
|
.service(generate)
|
||||||
|
.service(chat_completions)
|
||||||
|
})
|
||||||
|
.bind((lb_config.host, lb_config.port))?
|
||||||
|
.run()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
std::io::Result::Ok(())
|
||||||
|
}
|
||||||
182
sgl-pdlb/src/strategy_lb.rs
Normal file
182
sgl-pdlb/src/strategy_lb.rs
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
use rand::Rng;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum EngineType {
|
||||||
|
Prefill,
|
||||||
|
Decode,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct EngineInfo {
|
||||||
|
pub engine_type: EngineType,
|
||||||
|
pub url: String,
|
||||||
|
pub bootstrap_port: Option<u16>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EngineInfo {
|
||||||
|
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
|
||||||
|
EngineInfo {
|
||||||
|
engine_type: EngineType::Prefill,
|
||||||
|
url,
|
||||||
|
bootstrap_port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_decode(url: String) -> Self {
|
||||||
|
EngineInfo {
|
||||||
|
engine_type: EngineType::Decode,
|
||||||
|
url,
|
||||||
|
bootstrap_port: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn api_path(&self, api_path: &str) -> String {
|
||||||
|
if api_path.starts_with("/") {
|
||||||
|
format!("{}{}", self.url, api_path)
|
||||||
|
} else {
|
||||||
|
format!("{}/{}", self.url, api_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_string(&self) -> String {
|
||||||
|
format!("({:?}@{})", self.engine_type, self.url)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_hostname(&self) -> String {
|
||||||
|
let url = self
|
||||||
|
.url
|
||||||
|
.trim_start_matches("http://")
|
||||||
|
.trim_start_matches("https://");
|
||||||
|
url.split(':').next().unwrap().to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EngineLoad {
|
||||||
|
pub engine_info: EngineInfo,
|
||||||
|
pub load: isize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EngineLoad {
|
||||||
|
pub fn from_json(engine_info: &EngineInfo, json: &serde_json::Value) -> Self {
|
||||||
|
let load = match json.get("load") {
|
||||||
|
Some(load) => load.as_i64().unwrap_or(-1) as isize,
|
||||||
|
None => -1,
|
||||||
|
};
|
||||||
|
EngineLoad {
|
||||||
|
engine_info: engine_info.clone(),
|
||||||
|
load,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn to_json(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"engine": self.engine_info.to_string(),
|
||||||
|
"load": self.load,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_string(&self) -> String {
|
||||||
|
format!("{}: {}", self.engine_info.to_string(), self.load)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum LBPolicy {
|
||||||
|
Random,
|
||||||
|
PowerOfTwo,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StrategyLB {
|
||||||
|
pub policy: LBPolicy,
|
||||||
|
pub prefill_servers: Vec<EngineInfo>,
|
||||||
|
pub decode_servers: Vec<EngineInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StrategyLB {
|
||||||
|
pub fn new(
|
||||||
|
policy: LBPolicy,
|
||||||
|
prefill_servers: Vec<EngineInfo>,
|
||||||
|
decode_servers: Vec<EngineInfo>,
|
||||||
|
) -> Self {
|
||||||
|
StrategyLB {
|
||||||
|
policy,
|
||||||
|
prefill_servers,
|
||||||
|
decode_servers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_one_server(&self) -> EngineInfo {
|
||||||
|
assert!(!self.prefill_servers.is_empty());
|
||||||
|
assert!(!self.decode_servers.is_empty());
|
||||||
|
self.prefill_servers[0].clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_all_servers(&self) -> Vec<EngineInfo> {
|
||||||
|
let mut all_servers = Vec::new();
|
||||||
|
all_servers.extend(self.prefill_servers.clone());
|
||||||
|
all_servers.extend(self.decode_servers.clone());
|
||||||
|
all_servers
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_pair(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
|
||||||
|
match self.policy {
|
||||||
|
LBPolicy::Random => self.select_pd_pair_random(),
|
||||||
|
LBPolicy::PowerOfTwo => self.select_pd_pair_po2(client).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn select_pd_pair_random(&self) -> (EngineInfo, EngineInfo) {
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
let prefill_index = rng.random_range(0..self.prefill_servers.len());
|
||||||
|
let decode_index = rng.random_range(0..self.decode_servers.len());
|
||||||
|
|
||||||
|
(
|
||||||
|
self.prefill_servers[prefill_index].clone(),
|
||||||
|
self.decode_servers[decode_index].clone(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_load_from_engine(
|
||||||
|
&self,
|
||||||
|
client: &reqwest::Client,
|
||||||
|
engine_info: &EngineInfo,
|
||||||
|
) -> Option<isize> {
|
||||||
|
let url = engine_info.api_path("/get_load");
|
||||||
|
let response = client.get(url).send().await.unwrap();
|
||||||
|
match response.status() {
|
||||||
|
reqwest::StatusCode::OK => {
|
||||||
|
let data = response.json::<serde_json::Value>().await.unwrap();
|
||||||
|
Some(data["load"].as_i64().unwrap() as isize)
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn select_pd_pair_po2(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
|
||||||
|
let mut rng = rand::rng();
|
||||||
|
let prefill1 =
|
||||||
|
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
|
||||||
|
let prefill2 =
|
||||||
|
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
|
||||||
|
let decode1 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
|
||||||
|
let decode2 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
|
||||||
|
let prefill1_load = self.get_load_from_engine(client, &prefill1).await;
|
||||||
|
let prefill2_load = self.get_load_from_engine(client, &prefill2).await;
|
||||||
|
let decode1_load = self.get_load_from_engine(client, &decode1).await;
|
||||||
|
let decode2_load = self.get_load_from_engine(client, &decode2).await;
|
||||||
|
|
||||||
|
(
|
||||||
|
if prefill1_load < prefill2_load {
|
||||||
|
prefill1
|
||||||
|
} else {
|
||||||
|
prefill2
|
||||||
|
},
|
||||||
|
if decode1_load < decode2_load {
|
||||||
|
decode1
|
||||||
|
} else {
|
||||||
|
decode2
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user