[rust] cache-aware DP - approx tree (#1934)

This commit is contained in:
Byron Hsu
2024-11-10 21:57:32 -08:00
committed by GitHub
parent 087ab83223
commit f9633fa9b9
13 changed files with 1472 additions and 177 deletions

View File

@@ -1,37 +1,86 @@
// Python Binding
use pyo3::prelude::*;
pub mod router;
mod server;
pub mod tree;
// Python binding
#[pyclass(eq)]
#[derive(Clone, PartialEq)]
pub enum PolicyType {
Random,
RoundRobin,
ApproxTree,
}
#[pyclass]
struct Router {
host: String,
port: u16,
worker_urls: Vec<String>,
policy: String,
policy: PolicyType,
tokenizer_path: Option<String>,
cache_threshold: Option<f32>,
}
#[pymethods]
impl Router {
#[new]
fn new(host: String, port: u16, worker_urls: Vec<String>, policy: String) -> Self {
Router {
#[pyo3(signature = (
worker_urls,
policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"),
port = 3001,
tokenizer_path = None,
cache_threshold = Some(0.50)
))]
fn new(
worker_urls: Vec<String>,
policy: PolicyType,
host: String,
port: u16,
tokenizer_path: Option<String>,
cache_threshold: Option<f32>,
) -> PyResult<Self> {
// Validate required parameters for approx_tree policy
if matches!(policy, PolicyType::ApproxTree) {
if tokenizer_path.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"tokenizer_path is required for approx_tree policy",
));
}
}
Ok(Router {
host,
port,
worker_urls,
policy,
}
tokenizer_path,
cache_threshold,
})
}
fn start(&self) -> PyResult<()> {
let host = self.host.clone();
let port = self.port;
let worker_urls = self.worker_urls.clone();
let policy = self.policy.clone();
let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig,
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig {
tokenizer_path: self
.tokenizer_path
.clone()
.expect("tokenizer_path is required for approx_tree policy"),
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
},
};
actix_web::rt::System::new().block_on(async move {
server::startup(host, port, worker_urls, policy)
server::startup(host, port, worker_urls, policy_config)
.await
.unwrap();
});
@@ -40,9 +89,9 @@ impl Router {
}
}
// python usage: `from sglang_router import Router`
#[pymodule]
fn sglang_router(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PolicyType>()?;
m.add_class::<Router>()?;
Ok(())
}