[router] add py binding and readme for openai router and history backend (#11453)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Keyang Ru
2025-10-14 09:42:34 -07:00
committed by GitHub
parent 5ea96ac7cc
commit eb8cac6fe2
8 changed files with 488 additions and 25 deletions

View File

@@ -30,6 +30,113 @@ pub enum PolicyType {
PowerOfTwo,
}
#[pyclass(eq)]
#[derive(Clone, PartialEq, Debug)]
pub enum BackendType {
Sglang,
Openai,
}
#[pyclass(eq)]
#[derive(Clone, PartialEq, Debug)]
pub enum HistoryBackendType {
Memory,
None,
Oracle,
}
#[pyclass]
#[derive(Clone, PartialEq)]
pub struct PyOracleConfig {
#[pyo3(get, set)]
pub wallet_path: Option<String>,
#[pyo3(get, set)]
pub connect_descriptor: Option<String>,
#[pyo3(get, set)]
pub username: Option<String>,
#[pyo3(get, set)]
pub password: Option<String>,
#[pyo3(get, set)]
pub pool_min: usize,
#[pyo3(get, set)]
pub pool_max: usize,
#[pyo3(get, set)]
pub pool_timeout_secs: u64,
}
impl std::fmt::Debug for PyOracleConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PyOracleConfig")
.field("wallet_path", &self.wallet_path)
.field("connect_descriptor", &"<redacted>")
.field("username", &self.username)
.field("password", &"<redacted>")
.field("pool_min", &self.pool_min)
.field("pool_max", &self.pool_max)
.field("pool_timeout_secs", &self.pool_timeout_secs)
.finish()
}
}
#[pymethods]
impl PyOracleConfig {
#[new]
#[pyo3(signature = (
password = None,
username = None,
connect_descriptor = None,
wallet_path = None,
pool_min = 1,
pool_max = 16,
pool_timeout_secs = 30,
))]
fn new(
password: Option<String>,
username: Option<String>,
connect_descriptor: Option<String>,
wallet_path: Option<String>,
pool_min: usize,
pool_max: usize,
pool_timeout_secs: u64,
) -> PyResult<Self> {
if pool_min == 0 {
return Err(pyo3::exceptions::PyValueError::new_err(
"pool_min must be at least 1",
));
}
if pool_max < pool_min {
return Err(pyo3::exceptions::PyValueError::new_err(
"pool_max must be >= pool_min",
));
}
Ok(PyOracleConfig {
wallet_path,
connect_descriptor,
username,
password,
pool_min,
pool_max,
pool_timeout_secs,
})
}
}
impl PyOracleConfig {
fn to_config_oracle(&self) -> config::OracleConfig {
// Simple conversion - validation happens later in validate_oracle()
config::OracleConfig {
wallet_path: self.wallet_path.clone(),
connect_descriptor: self.connect_descriptor.clone().unwrap_or_default(),
username: self.username.clone().unwrap_or_default(),
password: self.password.clone().unwrap_or_default(),
pool_min: self.pool_min,
pool_max: self.pool_max,
pool_timeout_secs: self.pool_timeout_secs,
}
}
}
#[pyclass]
#[derive(Debug, Clone, PartialEq)]
struct Router {
@@ -93,6 +200,9 @@ struct Router {
chat_template: Option<String>,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
backend: BackendType,
history_backend: HistoryBackendType,
oracle_config: Option<PyOracleConfig>,
}
impl Router {
@@ -132,6 +242,10 @@ impl Router {
RoutingMode::Regular {
worker_urls: vec![],
}
} else if matches!(self.backend, BackendType::Openai) {
RoutingMode::OpenAI {
worker_urls: self.worker_urls.clone(),
}
} else if self.pd_disaggregation {
RoutingMode::PrefillDecode {
prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
@@ -170,6 +284,20 @@ impl Router {
_ => None,
};
let history_backend = match self.history_backend {
HistoryBackendType::Memory => config::HistoryBackend::Memory,
HistoryBackendType::None => config::HistoryBackend::None,
HistoryBackendType::Oracle => config::HistoryBackend::Oracle,
};
let oracle = if matches!(self.history_backend, HistoryBackendType::Oracle) {
self.oracle_config
.as_ref()
.map(|cfg| cfg.to_config_oracle())
} else {
None
};
Ok(config::RouterConfig {
mode,
policy,
@@ -218,8 +346,8 @@ impl Router {
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
chat_template: self.chat_template.clone(),
history_backend: config::HistoryBackend::Memory,
oracle: None,
history_backend,
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
})
@@ -289,6 +417,9 @@ impl Router {
chat_template = None,
reasoning_parser = None,
tool_call_parser = None,
backend = BackendType::Sglang,
history_backend = HistoryBackendType::Memory,
oracle_config = None,
))]
#[allow(clippy::too_many_arguments)]
fn new(
@@ -351,6 +482,9 @@ impl Router {
chat_template: Option<String>,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
backend: BackendType,
history_backend: HistoryBackendType,
oracle_config: Option<PyOracleConfig>,
) -> PyResult<Self> {
let mut all_urls = worker_urls.clone();
@@ -427,6 +561,9 @@ impl Router {
chat_template,
reasoning_parser,
tool_call_parser,
backend,
history_backend,
oracle_config,
})
}
@@ -491,6 +628,9 @@ impl Router {
#[pymodule]
fn sglang_router_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PolicyType>()?;
m.add_class::<BackendType>()?;
m.add_class::<HistoryBackendType>()?;
m.add_class::<PyOracleConfig>()?;
m.add_class::<Router>()?;
Ok(())
}