[misc] remove pdlb rust (#7796)
This commit is contained in:
@@ -1,2 +0,0 @@
|
|||||||
reorder_imports = true
|
|
||||||
reorder_modules = true
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
[package]
|
|
||||||
edition = "2024"
|
|
||||||
name = "sgl-pdlb"
|
|
||||||
version = "0.1.0"
|
|
||||||
|
|
||||||
[lib]
|
|
||||||
crate-type = ["cdylib", "rlib"]
|
|
||||||
name = "sgl_pdlb_rs"
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
actix-web = "4.11"
|
|
||||||
bytes = "1.8.0"
|
|
||||||
chrono = "0.4.38"
|
|
||||||
clap = { version = "4.4", features = ["derive"] }
|
|
||||||
dashmap = "6.1.0"
|
|
||||||
env_logger = "0.11.5"
|
|
||||||
futures = "0.3"
|
|
||||||
futures-util = "0.3"
|
|
||||||
http = "1.3.1"
|
|
||||||
log = "0.4.22"
|
|
||||||
pyo3 = { version = "0.25.0", features = ["extension-module"] }
|
|
||||||
rand = "0.9.0"
|
|
||||||
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
|
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
|
||||||
serde_json = "1.0"
|
|
||||||
tokio = { version = "1.34", features = ["full"] }
|
|
||||||
anyhow = "1.0.98"
|
|
||||||
typetag = "0.2.20"
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
### Install dependencies
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install "maturin[patchelf]"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Build and install
|
|
||||||
|
|
||||||
```bash
|
|
||||||
maturin develop
|
|
||||||
pip install -e .
|
|
||||||
```
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
__version__ = "0.0.1"
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
[build-system]
|
|
||||||
requires = ["maturin>=1.8.0"]
|
|
||||||
build-backend = "maturin"
|
|
||||||
|
|
||||||
[project]
|
|
||||||
name = "sgl_pdlb"
|
|
||||||
version = "0.0.1"
|
|
||||||
|
|
||||||
[tool.maturin]
|
|
||||||
python-source = "py_src"
|
|
||||||
module-name = "sgl_pdlb._rust"
|
|
||||||
|
|
||||||
[tool.maturin.build-backend]
|
|
||||||
features = ["pyo3/extension-module"]
|
|
||||||
@@ -1,133 +0,0 @@
|
|||||||
use crate::strategy_lb::EngineInfo;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum SingleOrBatch<T> {
|
|
||||||
Single(T),
|
|
||||||
Batch(Vec<T>),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type InputIds = SingleOrBatch<Vec<i32>>;
|
|
||||||
pub type InputText = SingleOrBatch<String>;
|
|
||||||
pub type BootstrapHost = SingleOrBatch<String>;
|
|
||||||
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
|
|
||||||
pub type BootstrapRoom = SingleOrBatch<u64>;
|
|
||||||
|
|
||||||
#[typetag::serde(tag = "type")]
|
|
||||||
pub trait Bootstrap {
|
|
||||||
fn is_stream(&self) -> bool;
|
|
||||||
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error>;
|
|
||||||
fn set_bootstrap_info(
|
|
||||||
&mut self,
|
|
||||||
bootstrap_host: BootstrapHost,
|
|
||||||
bootstrap_port: BootstrapPort,
|
|
||||||
bootstrap_room: BootstrapRoom,
|
|
||||||
);
|
|
||||||
|
|
||||||
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), actix_web::Error> {
|
|
||||||
let batch_size = self.get_batch_size()?;
|
|
||||||
if let Some(batch_size) = batch_size {
|
|
||||||
self.set_bootstrap_info(
|
|
||||||
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
|
|
||||||
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
|
|
||||||
BootstrapRoom::Batch((0..batch_size).map(|_| rand::random::<u64>()).collect()),
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
self.set_bootstrap_info(
|
|
||||||
BootstrapHost::Single(prefill_info.get_hostname()),
|
|
||||||
BootstrapPort::Single(prefill_info.bootstrap_port),
|
|
||||||
BootstrapRoom::Single(rand::random::<u64>()),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
|
||||||
pub struct GenerateReqInput {
|
|
||||||
pub text: Option<InputText>,
|
|
||||||
pub input_ids: Option<InputIds>,
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
pub bootstrap_host: Option<BootstrapHost>,
|
|
||||||
pub bootstrap_port: Option<BootstrapPort>,
|
|
||||||
pub bootstrap_room: Option<BootstrapRoom>,
|
|
||||||
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub other: Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenerateReqInput {
|
|
||||||
pub fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
|
|
||||||
if self.text.is_some() && self.input_ids.is_some() {
|
|
||||||
return Err(actix_web::error::ErrorBadRequest(
|
|
||||||
"Both text and input_ids are present in the request".to_string(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
if let Some(InputText::Batch(texts)) = &self.text {
|
|
||||||
return Ok(Some(texts.len()));
|
|
||||||
}
|
|
||||||
if let Some(InputIds::Batch(ids)) = &self.input_ids {
|
|
||||||
return Ok(Some(ids.len()));
|
|
||||||
}
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
impl Bootstrap for GenerateReqInput {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
|
|
||||||
self.get_batch_size()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_bootstrap_info(
|
|
||||||
&mut self,
|
|
||||||
bootstrap_host: BootstrapHost,
|
|
||||||
bootstrap_port: BootstrapPort,
|
|
||||||
bootstrap_room: BootstrapRoom,
|
|
||||||
) {
|
|
||||||
self.bootstrap_host = Some(bootstrap_host);
|
|
||||||
self.bootstrap_port = Some(bootstrap_port);
|
|
||||||
self.bootstrap_room = Some(bootstrap_room);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ChatReqInput {
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
pub bootstrap_host: Option<BootstrapHost>,
|
|
||||||
pub bootstrap_port: Option<BootstrapPort>,
|
|
||||||
pub bootstrap_room: Option<BootstrapRoom>,
|
|
||||||
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub other: Value,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[typetag::serde]
|
|
||||||
impl Bootstrap for ChatReqInput {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
|
|
||||||
Ok(None)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_bootstrap_info(
|
|
||||||
&mut self,
|
|
||||||
bootstrap_host: BootstrapHost,
|
|
||||||
bootstrap_port: BootstrapPort,
|
|
||||||
bootstrap_room: BootstrapRoom,
|
|
||||||
) {
|
|
||||||
self.bootstrap_host = Some(bootstrap_host);
|
|
||||||
self.bootstrap_port = Some(bootstrap_port);
|
|
||||||
self.bootstrap_room = Some(bootstrap_room);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
use crate::io_struct::Bootstrap;
|
|
||||||
use crate::strategy_lb::{EngineInfo, EngineLoad, EngineType, LBPolicy, StrategyLB};
|
|
||||||
use actix_web::HttpResponse;
|
|
||||||
use bytes::Bytes;
|
|
||||||
use futures::{Stream, StreamExt, future::join_all};
|
|
||||||
use reqwest::{Method, StatusCode};
|
|
||||||
use std::pin::Pin;
|
|
||||||
|
|
||||||
pub enum ProxyResponseBody {
|
|
||||||
Full(Bytes),
|
|
||||||
Stream(Pin<Box<dyn Stream<Item = Result<Bytes, actix_web::Error>> + Send>>),
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ProxyResponse {
|
|
||||||
pub status: StatusCode,
|
|
||||||
pub body: ProxyResponseBody,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProxyResponse {
|
|
||||||
pub fn to_json(&self) -> Result<serde_json::Value, actix_web::Error> {
|
|
||||||
match &self.body {
|
|
||||||
ProxyResponseBody::Full(body) => Ok(serde_json::from_slice(&body)?),
|
|
||||||
ProxyResponseBody::Stream(_) => Err(actix_web::error::ErrorBadRequest(
|
|
||||||
"Stream response is not supported",
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Into<Result<HttpResponse, actix_web::Error>> for ProxyResponse {
|
|
||||||
fn into(self) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
let status = actix_web::http::StatusCode::from_u16(self.status.as_u16()).map_err(|e| {
|
|
||||||
actix_web::error::ErrorBadGateway(format!("Invalid status code: {}", e))
|
|
||||||
})?;
|
|
||||||
match self.body {
|
|
||||||
ProxyResponseBody::Full(body) => Ok(HttpResponse::Ok().status(status).body(body)),
|
|
||||||
ProxyResponseBody::Stream(body) => Ok(HttpResponse::Ok()
|
|
||||||
.status(status)
|
|
||||||
.content_type("application/octet-stream")
|
|
||||||
.streaming(body)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct LBConfig {
|
|
||||||
pub host: String,
|
|
||||||
pub port: u16,
|
|
||||||
pub policy: String,
|
|
||||||
pub prefill_infos: Vec<(String, Option<u16>)>,
|
|
||||||
pub decode_infos: Vec<String>,
|
|
||||||
pub log_interval: u64,
|
|
||||||
pub timeout: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct LBState {
|
|
||||||
pub strategy_lb: StrategyLB,
|
|
||||||
pub client: reqwest::Client,
|
|
||||||
pub log_interval: u64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LBState {
|
|
||||||
pub fn new(lb_config: LBConfig) -> anyhow::Result<Self> {
|
|
||||||
let client = reqwest::Client::builder()
|
|
||||||
.timeout(std::time::Duration::from_secs(lb_config.timeout))
|
|
||||||
.build()?;
|
|
||||||
let policy = match lb_config.policy.as_str() {
|
|
||||||
"random" => LBPolicy::Random,
|
|
||||||
"po2" => LBPolicy::PowerOfTwo,
|
|
||||||
_ => anyhow::bail!("Invalid policy"),
|
|
||||||
};
|
|
||||||
let prefill_servers = lb_config
|
|
||||||
.prefill_infos
|
|
||||||
.into_iter()
|
|
||||||
.map(|(url, port)| EngineInfo::new_prefill(url, port))
|
|
||||||
.collect();
|
|
||||||
let decode_servers = lb_config
|
|
||||||
.decode_infos
|
|
||||||
.into_iter()
|
|
||||||
.map(|url| EngineInfo::new_decode(url))
|
|
||||||
.collect();
|
|
||||||
let lb = StrategyLB::new(policy, prefill_servers, decode_servers);
|
|
||||||
Ok(Self {
|
|
||||||
strategy_lb: lb,
|
|
||||||
client,
|
|
||||||
log_interval: lb_config.log_interval,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn route_one(
|
|
||||||
&self,
|
|
||||||
engine_info: &EngineInfo,
|
|
||||||
method: Method,
|
|
||||||
api_path: &str,
|
|
||||||
request: Option<&serde_json::Value>,
|
|
||||||
stream: bool,
|
|
||||||
) -> Result<ProxyResponse, actix_web::Error> {
|
|
||||||
let url = engine_info.api_path(api_path);
|
|
||||||
let request = request.unwrap_or(&serde_json::Value::Null);
|
|
||||||
let task = self.client.request(method, url).json(request).send();
|
|
||||||
let resp = task.await.map_err(actix_web::error::ErrorBadGateway)?;
|
|
||||||
// FIXME: handle error status code (map status code to error)
|
|
||||||
let status = resp.status();
|
|
||||||
let body = if stream {
|
|
||||||
let resp_stream = resp.bytes_stream().map(|r| {
|
|
||||||
r.map_err(actix_web::error::ErrorBadGateway)
|
|
||||||
.map(Bytes::from)
|
|
||||||
});
|
|
||||||
ProxyResponseBody::Stream(Box::pin(resp_stream))
|
|
||||||
} else {
|
|
||||||
let body = resp
|
|
||||||
.bytes()
|
|
||||||
.await
|
|
||||||
.map_err(actix_web::error::ErrorBadGateway)?;
|
|
||||||
ProxyResponseBody::Full(body)
|
|
||||||
};
|
|
||||||
Ok(ProxyResponse { status, body })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn route_collect(
|
|
||||||
&self,
|
|
||||||
engines: &Vec<EngineInfo>,
|
|
||||||
method: Method,
|
|
||||||
api_path: &str,
|
|
||||||
request: Option<&serde_json::Value>,
|
|
||||||
) -> Result<Vec<ProxyResponse>, actix_web::Error> {
|
|
||||||
let tasks = engines
|
|
||||||
.iter()
|
|
||||||
.map(|engine| self.route_one(engine, method.clone(), api_path, request, false));
|
|
||||||
let responses = join_all(tasks).await;
|
|
||||||
responses
|
|
||||||
.into_iter()
|
|
||||||
.map(|r| r.map_err(actix_web::error::ErrorBadGateway))
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn generate(
|
|
||||||
&self,
|
|
||||||
api_path: &str,
|
|
||||||
mut req: Box<dyn Bootstrap>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
let (prefill, decode) = self.strategy_lb.select_pair(&self.client).await;
|
|
||||||
let stream = req.is_stream();
|
|
||||||
req.add_bootstrap_info(&prefill)?;
|
|
||||||
let json = serde_json::to_value(req)?;
|
|
||||||
let prefill_task = self.route_one(&prefill, Method::POST, api_path, Some(&json), false);
|
|
||||||
let decode_task = self.route_one(&decode, Method::POST, api_path, Some(&json), stream);
|
|
||||||
let (_, decode_response) = tokio::join!(prefill_task, decode_task);
|
|
||||||
decode_response?.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn get_engine_loads(
|
|
||||||
&self,
|
|
||||||
) -> Result<(Vec<EngineLoad>, Vec<EngineLoad>), actix_web::Error> {
|
|
||||||
let servers = self.strategy_lb.get_all_servers();
|
|
||||||
let responses = self
|
|
||||||
.route_collect(&servers, Method::GET, "/get_load", None)
|
|
||||||
.await?;
|
|
||||||
let loads = responses
|
|
||||||
.into_iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(i, r)| Ok(EngineLoad::from_json(&servers[i], &r.to_json()?)))
|
|
||||||
.collect::<Result<Vec<EngineLoad>, actix_web::Error>>()?;
|
|
||||||
let mut prefill_loads = Vec::new();
|
|
||||||
let mut decode_loads = Vec::new();
|
|
||||||
for load in loads {
|
|
||||||
match load.engine_info.engine_type {
|
|
||||||
EngineType::Prefill => prefill_loads.push(load),
|
|
||||||
EngineType::Decode => decode_loads.push(load),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok((prefill_loads, decode_loads))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
pub mod io_struct;
|
|
||||||
pub mod lb_state;
|
|
||||||
pub mod server;
|
|
||||||
pub mod strategy_lb;
|
|
||||||
use pyo3::{exceptions::PyRuntimeError, prelude::*};
|
|
||||||
|
|
||||||
use lb_state::{LBConfig, LBState};
|
|
||||||
use server::{periodic_logging, startup};
|
|
||||||
use tokio::signal;
|
|
||||||
|
|
||||||
#[pyclass]
|
|
||||||
pub struct LoadBalancer {
|
|
||||||
lb_config: LBConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pymethods]
|
|
||||||
impl LoadBalancer {
|
|
||||||
#[new]
|
|
||||||
pub fn new(
|
|
||||||
host: String,
|
|
||||||
port: u16,
|
|
||||||
policy: String,
|
|
||||||
prefill_infos: Vec<(String, Option<u16>)>,
|
|
||||||
decode_infos: Vec<String>,
|
|
||||||
log_interval: u64,
|
|
||||||
timeout: u64,
|
|
||||||
) -> PyResult<Self> {
|
|
||||||
let lb_config = LBConfig {
|
|
||||||
host,
|
|
||||||
port,
|
|
||||||
policy,
|
|
||||||
prefill_infos,
|
|
||||||
decode_infos,
|
|
||||||
log_interval,
|
|
||||||
timeout,
|
|
||||||
};
|
|
||||||
Ok(LoadBalancer { lb_config })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn start(&self) -> PyResult<()> {
|
|
||||||
let lb_state = LBState::new(self.lb_config.clone()).map_err(|e| {
|
|
||||||
PyRuntimeError::new_err(format!("Failed to build load balancer: {}", e))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let ret: PyResult<()> = actix_web::rt::System::new().block_on(async move {
|
|
||||||
tokio::select! {
|
|
||||||
_ = periodic_logging(lb_state.clone()) => {
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
res = startup(self.lb_config.clone(), lb_state) => {
|
|
||||||
res.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
_ = signal::ctrl_c() => {
|
|
||||||
println!("Received Ctrl+C, shutting down");
|
|
||||||
std::process::exit(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
ret
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[pymodule]
|
|
||||||
fn _rust(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
|
|
||||||
m.add_class::<LoadBalancer>()?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
mod io_struct;
|
|
||||||
mod lb_state;
|
|
||||||
mod server;
|
|
||||||
mod strategy_lb;
|
|
||||||
|
|
||||||
use lb_state::{LBConfig, LBState};
|
|
||||||
use server::{periodic_logging, startup};
|
|
||||||
use tokio::signal;
|
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
|
||||||
// FIXME: test code, move to test folder
|
|
||||||
let prefill_infos = (0..8)
|
|
||||||
.map(|i| (format!("123.123.123.123:{}", i), None))
|
|
||||||
.collect::<Vec<(String, Option<u16>)>>();
|
|
||||||
|
|
||||||
let decode_infos = (0..32)
|
|
||||||
.map(|i| format!("233.233.233.233:{}", i))
|
|
||||||
.collect::<Vec<String>>();
|
|
||||||
|
|
||||||
let lb_config = LBConfig {
|
|
||||||
host: "localhost".to_string(),
|
|
||||||
port: 8080,
|
|
||||||
policy: "random".to_string(),
|
|
||||||
prefill_infos,
|
|
||||||
decode_infos,
|
|
||||||
log_interval: 5,
|
|
||||||
timeout: 600,
|
|
||||||
};
|
|
||||||
let lb_state = LBState::new(lb_config.clone()).map_err(|e| anyhow::anyhow!(e))?;
|
|
||||||
let ret: anyhow::Result<()> = actix_web::rt::System::new().block_on(async move {
|
|
||||||
tokio::select! {
|
|
||||||
_ = periodic_logging(lb_state.clone()) => {
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
res = startup(lb_config.clone(), lb_state) => {
|
|
||||||
res.map_err(|e| anyhow::anyhow!(e))?;
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
_ = signal::ctrl_c() => {
|
|
||||||
println!("Received Ctrl+C, shutting down");
|
|
||||||
std::process::exit(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
ret
|
|
||||||
}
|
|
||||||
@@ -1,183 +0,0 @@
|
|||||||
use crate::io_struct::{ChatReqInput, GenerateReqInput};
|
|
||||||
use crate::lb_state::{LBConfig, LBState};
|
|
||||||
use crate::strategy_lb::EngineType;
|
|
||||||
use actix_web::{HttpRequest, HttpResponse, HttpServer, get, post, web};
|
|
||||||
use reqwest::Method;
|
|
||||||
use serde_json::json;
|
|
||||||
use std::io::Write;
|
|
||||||
|
|
||||||
#[get("/health")]
|
|
||||||
pub async fn health(_req: HttpRequest, _: web::Data<LBState>) -> HttpResponse {
|
|
||||||
HttpResponse::Ok().body("Ok")
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/health_generate")]
|
|
||||||
pub async fn health_generate(
|
|
||||||
_req: HttpRequest,
|
|
||||||
app_state: web::Data<LBState>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
let servers = app_state.strategy_lb.get_all_servers();
|
|
||||||
app_state
|
|
||||||
.route_collect(&servers, Method::GET, "/health_generate", None)
|
|
||||||
.await?;
|
|
||||||
// FIXME: log the response
|
|
||||||
Ok(HttpResponse::Ok().body("Health check passed on all servers"))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/flush_cache")]
|
|
||||||
pub async fn flush_cache(
|
|
||||||
_req: HttpRequest,
|
|
||||||
app_state: web::Data<LBState>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
let servers = app_state.strategy_lb.get_all_servers();
|
|
||||||
app_state
|
|
||||||
.route_collect(&servers, Method::POST, "/flush_cache", None)
|
|
||||||
.await?;
|
|
||||||
Ok(HttpResponse::Ok().body("Cache flushed on all servers"))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/get_model_info")]
|
|
||||||
pub async fn get_model_info(
|
|
||||||
_req: HttpRequest,
|
|
||||||
app_state: web::Data<LBState>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
// Return the first server's model info
|
|
||||||
let engine = app_state.strategy_lb.get_one_server();
|
|
||||||
app_state
|
|
||||||
.route_one(&engine, Method::GET, "/get_model_info", None, false)
|
|
||||||
.await?
|
|
||||||
.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/generate")]
|
|
||||||
pub async fn generate(
|
|
||||||
_req: HttpRequest,
|
|
||||||
req: web::Json<GenerateReqInput>,
|
|
||||||
app_state: web::Data<LBState>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
app_state
|
|
||||||
.generate("/generate", Box::new(req.into_inner()))
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1/completions")]
|
|
||||||
pub async fn completions(
|
|
||||||
_req: HttpRequest,
|
|
||||||
req: web::Json<GenerateReqInput>,
|
|
||||||
app_state: web::Data<LBState>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
app_state
|
|
||||||
.generate("/v1/completions", Box::new(req.into_inner()))
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/v1/chat/completions")]
|
|
||||||
pub async fn chat_completions(
|
|
||||||
_req: HttpRequest,
|
|
||||||
req: web::Json<ChatReqInput>,
|
|
||||||
app_state: web::Data<LBState>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
app_state
|
|
||||||
.generate("/v1/chat/completions", Box::new(req.into_inner()))
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/get_server_info")]
|
|
||||||
pub async fn get_server_info(
|
|
||||||
_req: HttpRequest,
|
|
||||||
app_state: web::Data<LBState>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
let servers = app_state.strategy_lb.get_all_servers();
|
|
||||||
let responses = app_state
|
|
||||||
.route_collect(&servers, Method::GET, "/get_server_info", None)
|
|
||||||
.await?;
|
|
||||||
let mut prefill_infos = Vec::new();
|
|
||||||
let mut decode_infos = Vec::new();
|
|
||||||
for (i, resp) in responses.iter().enumerate() {
|
|
||||||
let json = resp.to_json()?;
|
|
||||||
match servers[i].engine_type {
|
|
||||||
EngineType::Prefill => prefill_infos.push(json),
|
|
||||||
EngineType::Decode => decode_infos.push(json),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(HttpResponse::Ok().json(json!({
|
|
||||||
"prefill": prefill_infos,
|
|
||||||
"decode": decode_infos,
|
|
||||||
})))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[get("/get_loads")]
|
|
||||||
pub async fn get_loads(
|
|
||||||
_req: HttpRequest,
|
|
||||||
app_state: web::Data<LBState>,
|
|
||||||
) -> Result<HttpResponse, actix_web::Error> {
|
|
||||||
let (prefill_loads, decode_loads) = app_state.get_engine_loads().await?;
|
|
||||||
Ok(HttpResponse::Ok().json(json!({
|
|
||||||
"prefill": prefill_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>(),
|
|
||||||
"decode": decode_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>()
|
|
||||||
})))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn periodic_logging(lb_state: LBState) {
|
|
||||||
// FIXME: currently we can just clone the lb_state to log as the lb is stateless
|
|
||||||
loop {
|
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(lb_state.log_interval)).await;
|
|
||||||
let (prefill_loads, decode_loads) = match lb_state.get_engine_loads().await {
|
|
||||||
Ok((prefill_loads, decode_loads)) => (prefill_loads, decode_loads),
|
|
||||||
Err(e) => {
|
|
||||||
log::error!("Failed to get engine loads: {}", e);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let prefill_loads = prefill_loads
|
|
||||||
.into_iter()
|
|
||||||
.map(|l| l.to_string())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let decode_loads = decode_loads
|
|
||||||
.into_iter()
|
|
||||||
.map(|l| l.to_string())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
log::info!("Prefill loads: {}", prefill_loads.join(", "));
|
|
||||||
log::info!("Decode loads: {}", decode_loads.join(", "));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn startup(lb_config: LBConfig, lb_state: LBState) -> std::io::Result<()> {
|
|
||||||
let app_state = web::Data::new(lb_state);
|
|
||||||
|
|
||||||
println!("Starting server at {}:{}", lb_config.host, lb_config.port);
|
|
||||||
|
|
||||||
// default level is info
|
|
||||||
env_logger::Builder::new()
|
|
||||||
.format(|buf, record| {
|
|
||||||
writeln!(
|
|
||||||
buf,
|
|
||||||
"{} - {} - {}",
|
|
||||||
chrono::Local::now().format("%Y-%m-%d %H:%M:%S"),
|
|
||||||
record.level(),
|
|
||||||
record.args()
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.filter(None, log::LevelFilter::Info)
|
|
||||||
.init();
|
|
||||||
|
|
||||||
HttpServer::new(move || {
|
|
||||||
actix_web::App::new()
|
|
||||||
.wrap(actix_web::middleware::Logger::default())
|
|
||||||
.app_data(app_state.clone())
|
|
||||||
.service(health)
|
|
||||||
.service(health_generate)
|
|
||||||
.service(flush_cache)
|
|
||||||
.service(get_model_info)
|
|
||||||
.service(get_server_info)
|
|
||||||
.service(get_loads)
|
|
||||||
.service(generate)
|
|
||||||
.service(chat_completions)
|
|
||||||
.service(completions)
|
|
||||||
})
|
|
||||||
.bind((lb_config.host, lb_config.port))?
|
|
||||||
.run()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
std::io::Result::Ok(())
|
|
||||||
}
|
|
||||||
@@ -1,182 +0,0 @@
|
|||||||
use rand::Rng;
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum EngineType {
|
|
||||||
Prefill,
|
|
||||||
Decode,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct EngineInfo {
|
|
||||||
pub engine_type: EngineType,
|
|
||||||
pub url: String,
|
|
||||||
pub bootstrap_port: Option<u16>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EngineInfo {
|
|
||||||
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
|
|
||||||
EngineInfo {
|
|
||||||
engine_type: EngineType::Prefill,
|
|
||||||
url,
|
|
||||||
bootstrap_port,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new_decode(url: String) -> Self {
|
|
||||||
EngineInfo {
|
|
||||||
engine_type: EngineType::Decode,
|
|
||||||
url,
|
|
||||||
bootstrap_port: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn api_path(&self, api_path: &str) -> String {
|
|
||||||
if api_path.starts_with("/") {
|
|
||||||
format!("{}{}", self.url, api_path)
|
|
||||||
} else {
|
|
||||||
format!("{}/{}", self.url, api_path)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_string(&self) -> String {
|
|
||||||
format!("({:?}@{})", self.engine_type, self.url)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_hostname(&self) -> String {
|
|
||||||
let url = self
|
|
||||||
.url
|
|
||||||
.trim_start_matches("http://")
|
|
||||||
.trim_start_matches("https://");
|
|
||||||
url.split(':').next().unwrap().to_string()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct EngineLoad {
|
|
||||||
pub engine_info: EngineInfo,
|
|
||||||
pub load: isize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EngineLoad {
|
|
||||||
pub fn from_json(engine_info: &EngineInfo, json: &serde_json::Value) -> Self {
|
|
||||||
let load = match json.get("load") {
|
|
||||||
Some(load) => load.as_i64().unwrap_or(-1) as isize,
|
|
||||||
None => -1,
|
|
||||||
};
|
|
||||||
EngineLoad {
|
|
||||||
engine_info: engine_info.clone(),
|
|
||||||
load,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pub fn to_json(&self) -> serde_json::Value {
|
|
||||||
json!({
|
|
||||||
"engine": self.engine_info.to_string(),
|
|
||||||
"load": self.load,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn to_string(&self) -> String {
|
|
||||||
format!("{}: {}", self.engine_info.to_string(), self.load)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum LBPolicy {
|
|
||||||
Random,
|
|
||||||
PowerOfTwo,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct StrategyLB {
|
|
||||||
pub policy: LBPolicy,
|
|
||||||
pub prefill_servers: Vec<EngineInfo>,
|
|
||||||
pub decode_servers: Vec<EngineInfo>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StrategyLB {
|
|
||||||
pub fn new(
|
|
||||||
policy: LBPolicy,
|
|
||||||
prefill_servers: Vec<EngineInfo>,
|
|
||||||
decode_servers: Vec<EngineInfo>,
|
|
||||||
) -> Self {
|
|
||||||
StrategyLB {
|
|
||||||
policy,
|
|
||||||
prefill_servers,
|
|
||||||
decode_servers,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_one_server(&self) -> EngineInfo {
|
|
||||||
assert!(!self.prefill_servers.is_empty());
|
|
||||||
assert!(!self.decode_servers.is_empty());
|
|
||||||
self.prefill_servers[0].clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_all_servers(&self) -> Vec<EngineInfo> {
|
|
||||||
let mut all_servers = Vec::new();
|
|
||||||
all_servers.extend(self.prefill_servers.clone());
|
|
||||||
all_servers.extend(self.decode_servers.clone());
|
|
||||||
all_servers
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn select_pair(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
|
|
||||||
match self.policy {
|
|
||||||
LBPolicy::Random => self.select_pd_pair_random(),
|
|
||||||
LBPolicy::PowerOfTwo => self.select_pd_pair_po2(client).await,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn select_pd_pair_random(&self) -> (EngineInfo, EngineInfo) {
|
|
||||||
let mut rng = rand::rng();
|
|
||||||
let prefill_index = rng.random_range(0..self.prefill_servers.len());
|
|
||||||
let decode_index = rng.random_range(0..self.decode_servers.len());
|
|
||||||
|
|
||||||
(
|
|
||||||
self.prefill_servers[prefill_index].clone(),
|
|
||||||
self.decode_servers[decode_index].clone(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_load_from_engine(
|
|
||||||
&self,
|
|
||||||
client: &reqwest::Client,
|
|
||||||
engine_info: &EngineInfo,
|
|
||||||
) -> Option<isize> {
|
|
||||||
let url = engine_info.api_path("/get_load");
|
|
||||||
let response = client.get(url).send().await.unwrap();
|
|
||||||
match response.status() {
|
|
||||||
reqwest::StatusCode::OK => {
|
|
||||||
let data = response.json::<serde_json::Value>().await.unwrap();
|
|
||||||
Some(data["load"].as_i64().unwrap() as isize)
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn select_pd_pair_po2(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
|
|
||||||
let mut rng = rand::rng();
|
|
||||||
let prefill1 =
|
|
||||||
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
|
|
||||||
let prefill2 =
|
|
||||||
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
|
|
||||||
let decode1 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
|
|
||||||
let decode2 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
|
|
||||||
let prefill1_load = self.get_load_from_engine(client, &prefill1).await;
|
|
||||||
let prefill2_load = self.get_load_from_engine(client, &prefill2).await;
|
|
||||||
let decode1_load = self.get_load_from_engine(client, &decode1).await;
|
|
||||||
let decode2_load = self.get_load_from_engine(client, &decode2).await;
|
|
||||||
|
|
||||||
(
|
|
||||||
if prefill1_load < prefill2_load {
|
|
||||||
prefill1
|
|
||||||
} else {
|
|
||||||
prefill2
|
|
||||||
},
|
|
||||||
if decode1_load < decode2_load {
|
|
||||||
decode1
|
|
||||||
} else {
|
|
||||||
decode2
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user