[misc] remove pdlb rust (#7796)
This commit is contained in:
@@ -1,2 +0,0 @@
|
||||
reorder_imports = true
|
||||
reorder_modules = true
|
||||
@@ -1,28 +0,0 @@
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "sgl-pdlb"
|
||||
version = "0.1.0"
|
||||
|
||||
[lib]
|
||||
crate-type = ["cdylib", "rlib"]
|
||||
name = "sgl_pdlb_rs"
|
||||
|
||||
[dependencies]
|
||||
actix-web = "4.11"
|
||||
bytes = "1.8.0"
|
||||
chrono = "0.4.38"
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
dashmap = "6.1.0"
|
||||
env_logger = "0.11.5"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
http = "1.3.1"
|
||||
log = "0.4.22"
|
||||
pyo3 = { version = "0.25.0", features = ["extension-module"] }
|
||||
rand = "0.9.0"
|
||||
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1.34", features = ["full"] }
|
||||
anyhow = "1.0.98"
|
||||
typetag = "0.2.20"
|
||||
@@ -1,12 +0,0 @@
|
||||
### Install dependencies
|
||||
|
||||
```bash
|
||||
pip install "maturin[patchelf]"
|
||||
```
|
||||
|
||||
### Build and install
|
||||
|
||||
```bash
|
||||
maturin develop
|
||||
pip install -e .
|
||||
```
|
||||
@@ -1 +0,0 @@
|
||||
__version__ = "0.0.1"
|
||||
@@ -1,14 +0,0 @@
|
||||
[build-system]
|
||||
requires = ["maturin>=1.8.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "sgl_pdlb"
|
||||
version = "0.0.1"
|
||||
|
||||
[tool.maturin]
|
||||
python-source = "py_src"
|
||||
module-name = "sgl_pdlb._rust"
|
||||
|
||||
[tool.maturin.build-backend]
|
||||
features = ["pyo3/extension-module"]
|
||||
@@ -1,133 +0,0 @@
|
||||
use crate::strategy_lb::EngineInfo;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum SingleOrBatch<T> {
|
||||
Single(T),
|
||||
Batch(Vec<T>),
|
||||
}
|
||||
|
||||
pub type InputIds = SingleOrBatch<Vec<i32>>;
|
||||
pub type InputText = SingleOrBatch<String>;
|
||||
pub type BootstrapHost = SingleOrBatch<String>;
|
||||
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
|
||||
pub type BootstrapRoom = SingleOrBatch<u64>;
|
||||
|
||||
#[typetag::serde(tag = "type")]
|
||||
pub trait Bootstrap {
|
||||
fn is_stream(&self) -> bool;
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error>;
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
);
|
||||
|
||||
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), actix_web::Error> {
|
||||
let batch_size = self.get_batch_size()?;
|
||||
if let Some(batch_size) = batch_size {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
|
||||
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
|
||||
BootstrapRoom::Batch((0..batch_size).map(|_| rand::random::<u64>()).collect()),
|
||||
);
|
||||
} else {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Single(prefill_info.get_hostname()),
|
||||
BootstrapPort::Single(prefill_info.bootstrap_port),
|
||||
BootstrapRoom::Single(rand::random::<u64>()),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct GenerateReqInput {
|
||||
pub text: Option<InputText>,
|
||||
pub input_ids: Option<InputIds>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
pub bootstrap_host: Option<BootstrapHost>,
|
||||
pub bootstrap_port: Option<BootstrapPort>,
|
||||
pub bootstrap_room: Option<BootstrapRoom>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub other: Value,
|
||||
}
|
||||
|
||||
impl GenerateReqInput {
|
||||
pub fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
|
||||
if self.text.is_some() && self.input_ids.is_some() {
|
||||
return Err(actix_web::error::ErrorBadRequest(
|
||||
"Both text and input_ids are present in the request".to_string(),
|
||||
));
|
||||
}
|
||||
if let Some(InputText::Batch(texts)) = &self.text {
|
||||
return Ok(Some(texts.len()));
|
||||
}
|
||||
if let Some(InputIds::Batch(ids)) = &self.input_ids {
|
||||
return Ok(Some(ids.len()));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
impl Bootstrap for GenerateReqInput {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
|
||||
self.get_batch_size()
|
||||
}
|
||||
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
) {
|
||||
self.bootstrap_host = Some(bootstrap_host);
|
||||
self.bootstrap_port = Some(bootstrap_port);
|
||||
self.bootstrap_room = Some(bootstrap_room);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ChatReqInput {
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
pub bootstrap_host: Option<BootstrapHost>,
|
||||
pub bootstrap_port: Option<BootstrapPort>,
|
||||
pub bootstrap_room: Option<BootstrapRoom>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub other: Value,
|
||||
}
|
||||
|
||||
#[typetag::serde]
|
||||
impl Bootstrap for ChatReqInput {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, actix_web::Error> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
) {
|
||||
self.bootstrap_host = Some(bootstrap_host);
|
||||
self.bootstrap_port = Some(bootstrap_port);
|
||||
self.bootstrap_room = Some(bootstrap_room);
|
||||
}
|
||||
}
|
||||
@@ -1,175 +0,0 @@
|
||||
use crate::io_struct::Bootstrap;
|
||||
use crate::strategy_lb::{EngineInfo, EngineLoad, EngineType, LBPolicy, StrategyLB};
|
||||
use actix_web::HttpResponse;
|
||||
use bytes::Bytes;
|
||||
use futures::{Stream, StreamExt, future::join_all};
|
||||
use reqwest::{Method, StatusCode};
|
||||
use std::pin::Pin;
|
||||
|
||||
pub enum ProxyResponseBody {
|
||||
Full(Bytes),
|
||||
Stream(Pin<Box<dyn Stream<Item = Result<Bytes, actix_web::Error>> + Send>>),
|
||||
}
|
||||
|
||||
pub struct ProxyResponse {
|
||||
pub status: StatusCode,
|
||||
pub body: ProxyResponseBody,
|
||||
}
|
||||
|
||||
impl ProxyResponse {
|
||||
pub fn to_json(&self) -> Result<serde_json::Value, actix_web::Error> {
|
||||
match &self.body {
|
||||
ProxyResponseBody::Full(body) => Ok(serde_json::from_slice(&body)?),
|
||||
ProxyResponseBody::Stream(_) => Err(actix_web::error::ErrorBadRequest(
|
||||
"Stream response is not supported",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<Result<HttpResponse, actix_web::Error>> for ProxyResponse {
|
||||
fn into(self) -> Result<HttpResponse, actix_web::Error> {
|
||||
let status = actix_web::http::StatusCode::from_u16(self.status.as_u16()).map_err(|e| {
|
||||
actix_web::error::ErrorBadGateway(format!("Invalid status code: {}", e))
|
||||
})?;
|
||||
match self.body {
|
||||
ProxyResponseBody::Full(body) => Ok(HttpResponse::Ok().status(status).body(body)),
|
||||
ProxyResponseBody::Stream(body) => Ok(HttpResponse::Ok()
|
||||
.status(status)
|
||||
.content_type("application/octet-stream")
|
||||
.streaming(body)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LBConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub policy: String,
|
||||
pub prefill_infos: Vec<(String, Option<u16>)>,
|
||||
pub decode_infos: Vec<String>,
|
||||
pub log_interval: u64,
|
||||
pub timeout: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LBState {
|
||||
pub strategy_lb: StrategyLB,
|
||||
pub client: reqwest::Client,
|
||||
pub log_interval: u64,
|
||||
}
|
||||
|
||||
impl LBState {
|
||||
pub fn new(lb_config: LBConfig) -> anyhow::Result<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(lb_config.timeout))
|
||||
.build()?;
|
||||
let policy = match lb_config.policy.as_str() {
|
||||
"random" => LBPolicy::Random,
|
||||
"po2" => LBPolicy::PowerOfTwo,
|
||||
_ => anyhow::bail!("Invalid policy"),
|
||||
};
|
||||
let prefill_servers = lb_config
|
||||
.prefill_infos
|
||||
.into_iter()
|
||||
.map(|(url, port)| EngineInfo::new_prefill(url, port))
|
||||
.collect();
|
||||
let decode_servers = lb_config
|
||||
.decode_infos
|
||||
.into_iter()
|
||||
.map(|url| EngineInfo::new_decode(url))
|
||||
.collect();
|
||||
let lb = StrategyLB::new(policy, prefill_servers, decode_servers);
|
||||
Ok(Self {
|
||||
strategy_lb: lb,
|
||||
client,
|
||||
log_interval: lb_config.log_interval,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn route_one(
|
||||
&self,
|
||||
engine_info: &EngineInfo,
|
||||
method: Method,
|
||||
api_path: &str,
|
||||
request: Option<&serde_json::Value>,
|
||||
stream: bool,
|
||||
) -> Result<ProxyResponse, actix_web::Error> {
|
||||
let url = engine_info.api_path(api_path);
|
||||
let request = request.unwrap_or(&serde_json::Value::Null);
|
||||
let task = self.client.request(method, url).json(request).send();
|
||||
let resp = task.await.map_err(actix_web::error::ErrorBadGateway)?;
|
||||
// FIXME: handle error status code (map status code to error)
|
||||
let status = resp.status();
|
||||
let body = if stream {
|
||||
let resp_stream = resp.bytes_stream().map(|r| {
|
||||
r.map_err(actix_web::error::ErrorBadGateway)
|
||||
.map(Bytes::from)
|
||||
});
|
||||
ProxyResponseBody::Stream(Box::pin(resp_stream))
|
||||
} else {
|
||||
let body = resp
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(actix_web::error::ErrorBadGateway)?;
|
||||
ProxyResponseBody::Full(body)
|
||||
};
|
||||
Ok(ProxyResponse { status, body })
|
||||
}
|
||||
|
||||
pub async fn route_collect(
|
||||
&self,
|
||||
engines: &Vec<EngineInfo>,
|
||||
method: Method,
|
||||
api_path: &str,
|
||||
request: Option<&serde_json::Value>,
|
||||
) -> Result<Vec<ProxyResponse>, actix_web::Error> {
|
||||
let tasks = engines
|
||||
.iter()
|
||||
.map(|engine| self.route_one(engine, method.clone(), api_path, request, false));
|
||||
let responses = join_all(tasks).await;
|
||||
responses
|
||||
.into_iter()
|
||||
.map(|r| r.map_err(actix_web::error::ErrorBadGateway))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn generate(
|
||||
&self,
|
||||
api_path: &str,
|
||||
mut req: Box<dyn Bootstrap>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let (prefill, decode) = self.strategy_lb.select_pair(&self.client).await;
|
||||
let stream = req.is_stream();
|
||||
req.add_bootstrap_info(&prefill)?;
|
||||
let json = serde_json::to_value(req)?;
|
||||
let prefill_task = self.route_one(&prefill, Method::POST, api_path, Some(&json), false);
|
||||
let decode_task = self.route_one(&decode, Method::POST, api_path, Some(&json), stream);
|
||||
let (_, decode_response) = tokio::join!(prefill_task, decode_task);
|
||||
decode_response?.into()
|
||||
}
|
||||
|
||||
pub async fn get_engine_loads(
|
||||
&self,
|
||||
) -> Result<(Vec<EngineLoad>, Vec<EngineLoad>), actix_web::Error> {
|
||||
let servers = self.strategy_lb.get_all_servers();
|
||||
let responses = self
|
||||
.route_collect(&servers, Method::GET, "/get_load", None)
|
||||
.await?;
|
||||
let loads = responses
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, r)| Ok(EngineLoad::from_json(&servers[i], &r.to_json()?)))
|
||||
.collect::<Result<Vec<EngineLoad>, actix_web::Error>>()?;
|
||||
let mut prefill_loads = Vec::new();
|
||||
let mut decode_loads = Vec::new();
|
||||
for load in loads {
|
||||
match load.engine_info.engine_type {
|
||||
EngineType::Prefill => prefill_loads.push(load),
|
||||
EngineType::Decode => decode_loads.push(load),
|
||||
}
|
||||
}
|
||||
Ok((prefill_loads, decode_loads))
|
||||
}
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
pub mod io_struct;
|
||||
pub mod lb_state;
|
||||
pub mod server;
|
||||
pub mod strategy_lb;
|
||||
use pyo3::{exceptions::PyRuntimeError, prelude::*};
|
||||
|
||||
use lb_state::{LBConfig, LBState};
|
||||
use server::{periodic_logging, startup};
|
||||
use tokio::signal;
|
||||
|
||||
#[pyclass]
|
||||
pub struct LoadBalancer {
|
||||
lb_config: LBConfig,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl LoadBalancer {
|
||||
#[new]
|
||||
pub fn new(
|
||||
host: String,
|
||||
port: u16,
|
||||
policy: String,
|
||||
prefill_infos: Vec<(String, Option<u16>)>,
|
||||
decode_infos: Vec<String>,
|
||||
log_interval: u64,
|
||||
timeout: u64,
|
||||
) -> PyResult<Self> {
|
||||
let lb_config = LBConfig {
|
||||
host,
|
||||
port,
|
||||
policy,
|
||||
prefill_infos,
|
||||
decode_infos,
|
||||
log_interval,
|
||||
timeout,
|
||||
};
|
||||
Ok(LoadBalancer { lb_config })
|
||||
}
|
||||
|
||||
pub fn start(&self) -> PyResult<()> {
|
||||
let lb_state = LBState::new(self.lb_config.clone()).map_err(|e| {
|
||||
PyRuntimeError::new_err(format!("Failed to build load balancer: {}", e))
|
||||
})?;
|
||||
|
||||
let ret: PyResult<()> = actix_web::rt::System::new().block_on(async move {
|
||||
tokio::select! {
|
||||
_ = periodic_logging(lb_state.clone()) => {
|
||||
unreachable!()
|
||||
}
|
||||
res = startup(self.lb_config.clone(), lb_state) => {
|
||||
res.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
|
||||
unreachable!()
|
||||
}
|
||||
_ = signal::ctrl_c() => {
|
||||
println!("Received Ctrl+C, shutting down");
|
||||
std::process::exit(0);
|
||||
}
|
||||
}
|
||||
});
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
fn _rust(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
|
||||
m.add_class::<LoadBalancer>()?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
mod io_struct;
|
||||
mod lb_state;
|
||||
mod server;
|
||||
mod strategy_lb;
|
||||
|
||||
use lb_state::{LBConfig, LBState};
|
||||
use server::{periodic_logging, startup};
|
||||
use tokio::signal;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// FIXME: test code, move to test folder
|
||||
let prefill_infos = (0..8)
|
||||
.map(|i| (format!("123.123.123.123:{}", i), None))
|
||||
.collect::<Vec<(String, Option<u16>)>>();
|
||||
|
||||
let decode_infos = (0..32)
|
||||
.map(|i| format!("233.233.233.233:{}", i))
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
let lb_config = LBConfig {
|
||||
host: "localhost".to_string(),
|
||||
port: 8080,
|
||||
policy: "random".to_string(),
|
||||
prefill_infos,
|
||||
decode_infos,
|
||||
log_interval: 5,
|
||||
timeout: 600,
|
||||
};
|
||||
let lb_state = LBState::new(lb_config.clone()).map_err(|e| anyhow::anyhow!(e))?;
|
||||
let ret: anyhow::Result<()> = actix_web::rt::System::new().block_on(async move {
|
||||
tokio::select! {
|
||||
_ = periodic_logging(lb_state.clone()) => {
|
||||
unreachable!()
|
||||
}
|
||||
res = startup(lb_config.clone(), lb_state) => {
|
||||
res.map_err(|e| anyhow::anyhow!(e))?;
|
||||
unreachable!()
|
||||
}
|
||||
_ = signal::ctrl_c() => {
|
||||
println!("Received Ctrl+C, shutting down");
|
||||
std::process::exit(0);
|
||||
}
|
||||
}
|
||||
});
|
||||
ret
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
use crate::io_struct::{ChatReqInput, GenerateReqInput};
|
||||
use crate::lb_state::{LBConfig, LBState};
|
||||
use crate::strategy_lb::EngineType;
|
||||
use actix_web::{HttpRequest, HttpResponse, HttpServer, get, post, web};
|
||||
use reqwest::Method;
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
|
||||
#[get("/health")]
|
||||
pub async fn health(_req: HttpRequest, _: web::Data<LBState>) -> HttpResponse {
|
||||
HttpResponse::Ok().body("Ok")
|
||||
}
|
||||
|
||||
#[get("/health_generate")]
|
||||
pub async fn health_generate(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let servers = app_state.strategy_lb.get_all_servers();
|
||||
app_state
|
||||
.route_collect(&servers, Method::GET, "/health_generate", None)
|
||||
.await?;
|
||||
// FIXME: log the response
|
||||
Ok(HttpResponse::Ok().body("Health check passed on all servers"))
|
||||
}
|
||||
|
||||
#[post("/flush_cache")]
|
||||
pub async fn flush_cache(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let servers = app_state.strategy_lb.get_all_servers();
|
||||
app_state
|
||||
.route_collect(&servers, Method::POST, "/flush_cache", None)
|
||||
.await?;
|
||||
Ok(HttpResponse::Ok().body("Cache flushed on all servers"))
|
||||
}
|
||||
|
||||
#[get("/get_model_info")]
|
||||
pub async fn get_model_info(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
// Return the first server's model info
|
||||
let engine = app_state.strategy_lb.get_one_server();
|
||||
app_state
|
||||
.route_one(&engine, Method::GET, "/get_model_info", None, false)
|
||||
.await?
|
||||
.into()
|
||||
}
|
||||
|
||||
#[post("/generate")]
|
||||
pub async fn generate(
|
||||
_req: HttpRequest,
|
||||
req: web::Json<GenerateReqInput>,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
app_state
|
||||
.generate("/generate", Box::new(req.into_inner()))
|
||||
.await
|
||||
}
|
||||
|
||||
#[post("/v1/completions")]
|
||||
pub async fn completions(
|
||||
_req: HttpRequest,
|
||||
req: web::Json<GenerateReqInput>,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
app_state
|
||||
.generate("/v1/completions", Box::new(req.into_inner()))
|
||||
.await
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
pub async fn chat_completions(
|
||||
_req: HttpRequest,
|
||||
req: web::Json<ChatReqInput>,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
app_state
|
||||
.generate("/v1/chat/completions", Box::new(req.into_inner()))
|
||||
.await
|
||||
}
|
||||
|
||||
#[get("/get_server_info")]
|
||||
pub async fn get_server_info(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let servers = app_state.strategy_lb.get_all_servers();
|
||||
let responses = app_state
|
||||
.route_collect(&servers, Method::GET, "/get_server_info", None)
|
||||
.await?;
|
||||
let mut prefill_infos = Vec::new();
|
||||
let mut decode_infos = Vec::new();
|
||||
for (i, resp) in responses.iter().enumerate() {
|
||||
let json = resp.to_json()?;
|
||||
match servers[i].engine_type {
|
||||
EngineType::Prefill => prefill_infos.push(json),
|
||||
EngineType::Decode => decode_infos.push(json),
|
||||
}
|
||||
}
|
||||
Ok(HttpResponse::Ok().json(json!({
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
})))
|
||||
}
|
||||
|
||||
#[get("/get_loads")]
|
||||
pub async fn get_loads(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let (prefill_loads, decode_loads) = app_state.get_engine_loads().await?;
|
||||
Ok(HttpResponse::Ok().json(json!({
|
||||
"prefill": prefill_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>(),
|
||||
"decode": decode_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>()
|
||||
})))
|
||||
}
|
||||
|
||||
pub async fn periodic_logging(lb_state: LBState) {
|
||||
// FIXME: currently we can just clone the lb_state to log as the lb is stateless
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(lb_state.log_interval)).await;
|
||||
let (prefill_loads, decode_loads) = match lb_state.get_engine_loads().await {
|
||||
Ok((prefill_loads, decode_loads)) => (prefill_loads, decode_loads),
|
||||
Err(e) => {
|
||||
log::error!("Failed to get engine loads: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let prefill_loads = prefill_loads
|
||||
.into_iter()
|
||||
.map(|l| l.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
let decode_loads = decode_loads
|
||||
.into_iter()
|
||||
.map(|l| l.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
log::info!("Prefill loads: {}", prefill_loads.join(", "));
|
||||
log::info!("Decode loads: {}", decode_loads.join(", "));
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn startup(lb_config: LBConfig, lb_state: LBState) -> std::io::Result<()> {
|
||||
let app_state = web::Data::new(lb_state);
|
||||
|
||||
println!("Starting server at {}:{}", lb_config.host, lb_config.port);
|
||||
|
||||
// default level is info
|
||||
env_logger::Builder::new()
|
||||
.format(|buf, record| {
|
||||
writeln!(
|
||||
buf,
|
||||
"{} - {} - {}",
|
||||
chrono::Local::now().format("%Y-%m-%d %H:%M:%S"),
|
||||
record.level(),
|
||||
record.args()
|
||||
)
|
||||
})
|
||||
.filter(None, log::LevelFilter::Info)
|
||||
.init();
|
||||
|
||||
HttpServer::new(move || {
|
||||
actix_web::App::new()
|
||||
.wrap(actix_web::middleware::Logger::default())
|
||||
.app_data(app_state.clone())
|
||||
.service(health)
|
||||
.service(health_generate)
|
||||
.service(flush_cache)
|
||||
.service(get_model_info)
|
||||
.service(get_server_info)
|
||||
.service(get_loads)
|
||||
.service(generate)
|
||||
.service(chat_completions)
|
||||
.service(completions)
|
||||
})
|
||||
.bind((lb_config.host, lb_config.port))?
|
||||
.run()
|
||||
.await?;
|
||||
|
||||
std::io::Result::Ok(())
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
use rand::Rng;
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum EngineType {
|
||||
Prefill,
|
||||
Decode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EngineInfo {
|
||||
pub engine_type: EngineType,
|
||||
pub url: String,
|
||||
pub bootstrap_port: Option<u16>,
|
||||
}
|
||||
|
||||
impl EngineInfo {
|
||||
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
|
||||
EngineInfo {
|
||||
engine_type: EngineType::Prefill,
|
||||
url,
|
||||
bootstrap_port,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_decode(url: String) -> Self {
|
||||
EngineInfo {
|
||||
engine_type: EngineType::Decode,
|
||||
url,
|
||||
bootstrap_port: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn api_path(&self, api_path: &str) -> String {
|
||||
if api_path.starts_with("/") {
|
||||
format!("{}{}", self.url, api_path)
|
||||
} else {
|
||||
format!("{}/{}", self.url, api_path)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> String {
|
||||
format!("({:?}@{})", self.engine_type, self.url)
|
||||
}
|
||||
|
||||
pub fn get_hostname(&self) -> String {
|
||||
let url = self
|
||||
.url
|
||||
.trim_start_matches("http://")
|
||||
.trim_start_matches("https://");
|
||||
url.split(':').next().unwrap().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EngineLoad {
|
||||
pub engine_info: EngineInfo,
|
||||
pub load: isize,
|
||||
}
|
||||
|
||||
impl EngineLoad {
|
||||
pub fn from_json(engine_info: &EngineInfo, json: &serde_json::Value) -> Self {
|
||||
let load = match json.get("load") {
|
||||
Some(load) => load.as_i64().unwrap_or(-1) as isize,
|
||||
None => -1,
|
||||
};
|
||||
EngineLoad {
|
||||
engine_info: engine_info.clone(),
|
||||
load,
|
||||
}
|
||||
}
|
||||
pub fn to_json(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"engine": self.engine_info.to_string(),
|
||||
"load": self.load,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_string(&self) -> String {
|
||||
format!("{}: {}", self.engine_info.to_string(), self.load)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LBPolicy {
|
||||
Random,
|
||||
PowerOfTwo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StrategyLB {
|
||||
pub policy: LBPolicy,
|
||||
pub prefill_servers: Vec<EngineInfo>,
|
||||
pub decode_servers: Vec<EngineInfo>,
|
||||
}
|
||||
|
||||
impl StrategyLB {
|
||||
pub fn new(
|
||||
policy: LBPolicy,
|
||||
prefill_servers: Vec<EngineInfo>,
|
||||
decode_servers: Vec<EngineInfo>,
|
||||
) -> Self {
|
||||
StrategyLB {
|
||||
policy,
|
||||
prefill_servers,
|
||||
decode_servers,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_one_server(&self) -> EngineInfo {
|
||||
assert!(!self.prefill_servers.is_empty());
|
||||
assert!(!self.decode_servers.is_empty());
|
||||
self.prefill_servers[0].clone()
|
||||
}
|
||||
|
||||
pub fn get_all_servers(&self) -> Vec<EngineInfo> {
|
||||
let mut all_servers = Vec::new();
|
||||
all_servers.extend(self.prefill_servers.clone());
|
||||
all_servers.extend(self.decode_servers.clone());
|
||||
all_servers
|
||||
}
|
||||
|
||||
pub async fn select_pair(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
|
||||
match self.policy {
|
||||
LBPolicy::Random => self.select_pd_pair_random(),
|
||||
LBPolicy::PowerOfTwo => self.select_pd_pair_po2(client).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn select_pd_pair_random(&self) -> (EngineInfo, EngineInfo) {
|
||||
let mut rng = rand::rng();
|
||||
let prefill_index = rng.random_range(0..self.prefill_servers.len());
|
||||
let decode_index = rng.random_range(0..self.decode_servers.len());
|
||||
|
||||
(
|
||||
self.prefill_servers[prefill_index].clone(),
|
||||
self.decode_servers[decode_index].clone(),
|
||||
)
|
||||
}
|
||||
|
||||
async fn get_load_from_engine(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
engine_info: &EngineInfo,
|
||||
) -> Option<isize> {
|
||||
let url = engine_info.api_path("/get_load");
|
||||
let response = client.get(url).send().await.unwrap();
|
||||
match response.status() {
|
||||
reqwest::StatusCode::OK => {
|
||||
let data = response.json::<serde_json::Value>().await.unwrap();
|
||||
Some(data["load"].as_i64().unwrap() as isize)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn select_pd_pair_po2(&self, client: &reqwest::Client) -> (EngineInfo, EngineInfo) {
|
||||
let mut rng = rand::rng();
|
||||
let prefill1 =
|
||||
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
|
||||
let prefill2 =
|
||||
self.prefill_servers[rng.random_range(0..self.prefill_servers.len())].clone();
|
||||
let decode1 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
|
||||
let decode2 = self.decode_servers[rng.random_range(0..self.decode_servers.len())].clone();
|
||||
let prefill1_load = self.get_load_from_engine(client, &prefill1).await;
|
||||
let prefill2_load = self.get_load_from_engine(client, &prefill2).await;
|
||||
let decode1_load = self.get_load_from_engine(client, &decode1).await;
|
||||
let decode2_load = self.get_load_from_engine(client, &decode2).await;
|
||||
|
||||
(
|
||||
if prefill1_load < prefill2_load {
|
||||
prefill1
|
||||
} else {
|
||||
prefill2
|
||||
},
|
||||
if decode1_load < decode2_load {
|
||||
decode1
|
||||
} else {
|
||||
decode2
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user