Files
sglang/rust/src/lib.rs
2024-11-11 12:19:32 -08:00

98 lines
2.5 KiB
Rust

// Python Binding
use pyo3::prelude::*;
pub mod router;
mod server;
pub mod tree;
#[pyclass(eq)]
#[derive(Clone, PartialEq)]
pub enum PolicyType {
Random,
RoundRobin,
ApproxTree,
}
#[pyclass]
struct Router {
host: String,
port: u16,
worker_urls: Vec<String>,
policy: PolicyType,
tokenizer_path: Option<String>,
cache_threshold: Option<f32>,
}
#[pymethods]
impl Router {
#[new]
#[pyo3(signature = (
worker_urls,
policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"),
port = 3001,
tokenizer_path = None,
cache_threshold = Some(0.50)
))]
fn new(
worker_urls: Vec<String>,
policy: PolicyType,
host: String,
port: u16,
tokenizer_path: Option<String>,
cache_threshold: Option<f32>,
) -> PyResult<Self> {
// Validate required parameters for approx_tree policy
if matches!(policy, PolicyType::ApproxTree) {
if tokenizer_path.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"tokenizer_path is required for approx_tree policy",
));
}
}
Ok(Router {
host,
port,
worker_urls,
policy,
tokenizer_path,
cache_threshold,
})
}
fn start(&self) -> PyResult<()> {
let host = self.host.clone();
let port = self.port;
let worker_urls = self.worker_urls.clone();
let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig,
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
PolicyType::ApproxTree => router::PolicyConfig::ApproxTreeConfig {
tokenizer_path: self
.tokenizer_path
.clone()
.expect("tokenizer_path is required for approx_tree policy"),
cache_threshold: self
.cache_threshold
.expect("cache_threshold is required for approx_tree policy"),
},
};
actix_web::rt::System::new().block_on(async move {
server::startup(host, port, worker_urls, policy_config)
.await
.unwrap();
});
Ok(())
}
}
#[pymodule]
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PolicyType>()?;
m.add_class::<Router>()?;
Ok(())
}