sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,195 @@
//! Factory for creating router instances
use super::{
http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
RouterTrait,
};
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
use crate::policies::PolicyFactory;
use crate::server::AppContext;
use std::sync::Arc;
/// Factory for creating router instances based on configuration.
///
/// This is a unit struct: every constructor in its `impl` is an associated
/// function that takes no `self`, so no instance is ever needed.
pub struct RouterFactory;
impl RouterFactory {
/// Build the router selected by the configuration stored in `ctx`.
///
/// IGW mode is checked first and short-circuits everything else. Otherwise the
/// `(connection_mode, mode)` pair decides which concrete router is constructed.
///
/// # Errors
/// Propagates the error of the selected constructor, and returns a message for
/// the unsupported combination of OpenAI routing over a gRPC connection.
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
    let config = &ctx.router_config;

    // IGW takes precedence over every other setting.
    if config.enable_igw {
        return Self::create_igw_router(ctx).await;
    }

    match (&config.connection_mode, &config.mode) {
        // ---- gRPC transports ----
        (ConnectionMode::Grpc, RoutingMode::Regular { worker_urls }) => {
            Self::create_grpc_router(worker_urls, &config.policy, ctx).await
        }
        (
            ConnectionMode::Grpc,
            RoutingMode::PrefillDecode {
                prefill_urls,
                decode_urls,
                prefill_policy,
                decode_policy,
            },
        ) => {
            Self::create_grpc_pd_router(
                prefill_urls,
                decode_urls,
                prefill_policy.as_ref(),
                decode_policy.as_ref(),
                &config.policy,
                ctx,
            )
            .await
        }
        (ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => {
            Err("OpenAI mode requires HTTP connection_mode".to_string())
        }

        // ---- HTTP transports ----
        (ConnectionMode::Http, RoutingMode::Regular { worker_urls }) => {
            Self::create_regular_router(worker_urls, &config.policy, ctx).await
        }
        (
            ConnectionMode::Http,
            RoutingMode::PrefillDecode {
                prefill_urls,
                decode_urls,
                prefill_policy,
                decode_policy,
            },
        ) => {
            Self::create_pd_router(
                prefill_urls,
                decode_urls,
                prefill_policy.as_ref(),
                decode_policy.as_ref(),
                &config.policy,
                ctx,
            )
            .await
        }
        (ConnectionMode::Http, RoutingMode::OpenAI { worker_urls, .. }) => {
            Self::create_openai_router(worker_urls.clone(), ctx).await
        }
    }
}
/// Build an HTTP [`Router`] over `worker_urls`, using the policy described by
/// `policy_config`.
///
/// # Errors
/// Propagates any error returned by `Router::new`.
async fn create_regular_router(
    worker_urls: &[String],
    policy_config: &PolicyConfig,
    ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
    let policy = PolicyFactory::create_from_config(policy_config);
    let urls = worker_urls.to_vec();
    Ok(Box::new(Router::new(urls, policy, ctx).await?))
}
/// Build an HTTP [`PDRouter`] with separate prefill and decode policies.
///
/// A side-specific policy config is used when given; otherwise that side falls
/// back to `main_policy_config`.
///
/// # Errors
/// Propagates any error returned by `PDRouter::new`.
async fn create_pd_router(
    prefill_urls: &[(String, Option<u16>)],
    decode_urls: &[String],
    prefill_policy_config: Option<&PolicyConfig>,
    decode_policy_config: Option<&PolicyConfig>,
    main_policy_config: &PolicyConfig,
    ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
    // Resolve which config each side uses before instantiating the policies.
    let prefill_config = prefill_policy_config.unwrap_or(main_policy_config);
    let decode_config = decode_policy_config.unwrap_or(main_policy_config);
    let prefill_policy = PolicyFactory::create_from_config(prefill_config);
    let decode_policy = PolicyFactory::create_from_config(decode_config);

    let pd_router = PDRouter::new(
        prefill_urls.to_vec(),
        decode_urls.to_vec(),
        prefill_policy,
        decode_policy,
        ctx,
    )
    .await?;
    Ok(Box::new(pd_router))
}
/// Build a [`GrpcRouter`](super::grpc::router::GrpcRouter) over `worker_urls`,
/// using the policy described by `policy_config`.
///
/// # Errors
/// Propagates any error returned by `GrpcRouter::new`.
pub async fn create_grpc_router(
    worker_urls: &[String],
    policy_config: &PolicyConfig,
    ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
    use super::grpc::router::GrpcRouter;

    let policy = PolicyFactory::create_from_config(policy_config);
    let urls = worker_urls.to_vec();
    GrpcRouter::new(urls, policy, ctx)
        .await
        .map(|router| Box::new(router) as Box<dyn RouterTrait>)
}
/// Build a [`GrpcPDRouter`](super::grpc::pd_router::GrpcPDRouter) with separate
/// prefill and decode policies.
///
/// A side-specific policy config is used when given; otherwise that side falls
/// back to `main_policy_config`.
///
/// # Errors
/// Propagates any error returned by `GrpcPDRouter::new`.
pub async fn create_grpc_pd_router(
    prefill_urls: &[(String, Option<u16>)],
    decode_urls: &[String],
    prefill_policy_config: Option<&PolicyConfig>,
    decode_policy_config: Option<&PolicyConfig>,
    main_policy_config: &PolicyConfig,
    ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
    use super::grpc::pd_router::GrpcPDRouter;

    // Resolve which config each side uses before instantiating the policies.
    let prefill_config = prefill_policy_config.unwrap_or(main_policy_config);
    let decode_config = decode_policy_config.unwrap_or(main_policy_config);
    let prefill_policy = PolicyFactory::create_from_config(prefill_config);
    let decode_policy = PolicyFactory::create_from_config(decode_config);

    GrpcPDRouter::new(
        prefill_urls.to_vec(),
        decode_urls.to_vec(),
        prefill_policy,
        decode_policy,
        ctx,
    )
    .await
    .map(|router| Box::new(router) as Box<dyn RouterTrait>)
}
/// Build an [`OpenAIRouter`] that targets the first entry of `worker_urls`.
///
/// Any additional URLs are ignored.
///
/// # Errors
/// Returns an error when `worker_urls` is empty, or propagates the error from
/// `OpenAIRouter::new`.
async fn create_openai_router(
    worker_urls: Vec<String>,
    ctx: &Arc<AppContext>,
) -> Result<Box<dyn RouterTrait>, String> {
    // The first worker URL serves as the OpenAI-compatible base URL.
    let base_url = match worker_urls.into_iter().next() {
        Some(url) => url,
        None => return Err("OpenAI mode requires at least one worker URL".to_string()),
    };

    // NOTE(review): the gRPC routers read `effective_circuit_breaker_config()`,
    // while this path passes the raw `circuit_breaker` section — confirm that
    // the difference is intentional.
    let breaker_config = ctx.router_config.circuit_breaker.clone();
    OpenAIRouter::new(base_url, Some(breaker_config))
        .await
        .map(|router| Box::new(router) as Box<dyn RouterTrait>)
}
/// Placeholder for the IGW router.
///
/// # Errors
/// Always fails: IGW mode has no implementation yet, so the context is unused.
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
    let reason = "IGW mode is not yet implemented";
    Err(reason.to_string())
}
}

View File

@@ -0,0 +1,4 @@
//! gRPC router implementations
pub mod pd_router;
pub mod router;

View File

@@ -0,0 +1,328 @@
// PD (Prefill-Decode) gRPC Router Implementation
use crate::config::types::RetryConfig;
use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tracing::{info, warn};
/// gRPC PD (Prefill-Decode) router implementation for SGLang.
///
/// Keeps separate worker lists, gRPC client maps and load-balancing policies for
/// the prefill side and the decode side, together with components and settings
/// copied out of the application context by `new`.
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
/// Prefill worker connections
prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
/// Decode worker connections
decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
/// gRPC clients for prefill workers, keyed by worker URL
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// gRPC clients for decode workers, keyed by worker URL
decode_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// Load balancing policy for prefill
prefill_policy: Arc<dyn LoadBalancingPolicy>,
/// Load balancing policy for decode
decode_policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding
tokenizer: Arc<dyn Tokenizer>,
/// Reasoning parser factory for structured reasoning outputs
reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry,
/// Health checker started over `prefill_workers` in `new`
_prefill_health_checker: Option<HealthChecker>,
/// Health checker started over `decode_workers` in `new`
_decode_health_checker: Option<HealthChecker>,
// ---- Configuration (copied from `ctx.router_config` in `new`) ----
/// Taken from `worker_startup_timeout_secs`
timeout_secs: u64,
/// Taken from `worker_startup_check_interval_secs`
interval_secs: u64,
/// Taken from `dp_aware`
dp_aware: bool,
/// Taken from `api_key`
api_key: Option<String>,
/// Result of `effective_retry_config()`
retry_config: RetryConfig,
/// Core circuit-breaker settings converted from `effective_circuit_breaker_config()`
circuit_breaker_config: CircuitBreakerConfig,
}
impl GrpcPDRouter {
/// Create a new gRPC PD router.
///
/// Steps, in order: publish the configured worker count to the metrics, pull
/// the tokenizer / reasoning parser factory / tool parser registry out of `ctx`,
/// connect a gRPC client to every prefill and decode URL, build one worker
/// object per URL, initialise cache-aware policies, and start one health
/// checker per side.
///
/// # Errors
/// Fails when `ctx` lacks a tokenizer, reasoning parser factory or tool parser
/// registry, or when no gRPC connection at all could be established.
pub async fn new(
    prefill_urls: Vec<(String, Option<u16>)>,
    decode_urls: Vec<String>,
    prefill_policy: Arc<dyn LoadBalancingPolicy>,
    decode_policy: Arc<dyn LoadBalancingPolicy>,
    ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
    let router_config = &ctx.router_config;

    // The metric reflects the configured URLs, not the connections that succeed.
    RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());

    // Mandatory components from the application context.
    let tokenizer = ctx
        .tokenizer
        .clone()
        .ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?;
    let reasoning_parser_factory = ctx
        .reasoning_parser_factory
        .clone()
        .ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?;
    let tool_parser_registry = ctx
        .tool_parser_registry
        .ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?;

    // Translate the config-level circuit breaker settings (seconds) into the
    // core representation (`Duration`s).
    let cb_settings = router_config.effective_circuit_breaker_config();
    let core_cb_config = CircuitBreakerConfig {
        failure_threshold: cb_settings.failure_threshold,
        success_threshold: cb_settings.success_threshold,
        timeout_duration: Duration::from_secs(cb_settings.timeout_duration_secs),
        window_duration: Duration::from_secs(cb_settings.window_duration_secs),
    };

    // Connect to every prefill worker; a failed connection is logged and skipped.
    let mut prefill_grpc_clients = HashMap::new();
    for (url, _bootstrap_port) in prefill_urls.iter() {
        match SglangSchedulerClient::connect(url).await {
            Err(e) => {
                warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
            }
            Ok(client) => {
                prefill_grpc_clients.insert(url.clone(), client);
                info!("Connected to gRPC prefill worker at {}", url);
            }
        }
    }

    // Same for the decode workers.
    let mut decode_grpc_clients = HashMap::new();
    for url in decode_urls.iter() {
        match SglangSchedulerClient::connect(url).await {
            Err(e) => {
                warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
            }
            Ok(client) => {
                decode_grpc_clients.insert(url.clone(), client);
                info!("Connected to gRPC decode worker at {}", url);
            }
        }
    }

    // NOTE(review): this only fails when BOTH sides have zero connections; a
    // router with prefill clients but no decode clients (or vice versa) is
    // still constructed — confirm that is intended.
    let nothing_connected = prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty();
    if nothing_connected {
        return Err("Failed to connect to any gRPC workers".to_string());
    }

    // Every worker gets its own copy of the health-check settings.
    let health_settings = &router_config.health_check;
    let make_health_config = || HealthConfig {
        timeout_secs: health_settings.timeout_secs,
        check_interval_secs: health_settings.check_interval_secs,
        endpoint: health_settings.endpoint.clone(),
        failure_threshold: health_settings.failure_threshold,
        success_threshold: health_settings.success_threshold,
    };

    // One worker per configured URL, whether or not its connection succeeded.
    // NOTE(review): the bootstrap port doubles as the gRPC `port` here — verify.
    let mut prefill_workers: Vec<Box<dyn Worker>> = Vec::with_capacity(prefill_urls.len());
    for (url, bootstrap_port) in prefill_urls.iter() {
        let worker = BasicWorker::with_connection_mode(
            url.clone(),
            WorkerType::Prefill {
                bootstrap_port: *bootstrap_port,
            },
            crate::core::ConnectionMode::Grpc {
                port: *bootstrap_port,
            },
        )
        .with_circuit_breaker_config(core_cb_config.clone())
        .with_health_config(make_health_config());
        prefill_workers.push(Box::new(worker));
    }

    let mut decode_workers: Vec<Box<dyn Worker>> = Vec::with_capacity(decode_urls.len());
    for url in decode_urls.iter() {
        let worker = BasicWorker::with_connection_mode(
            url.clone(),
            WorkerType::Decode,
            crate::core::ConnectionMode::Grpc { port: None },
        )
        .with_circuit_breaker_config(core_cb_config.clone())
        .with_health_config(make_health_config());
        decode_workers.push(Box::new(worker));
    }

    // Cache-aware policies need to be told about the worker set up front.
    if let Some(cache_aware) = prefill_policy
        .as_any()
        .downcast_ref::<crate::policies::CacheAwarePolicy>()
    {
        cache_aware.init_workers(&prefill_workers);
    }
    if let Some(cache_aware) = decode_policy
        .as_any()
        .downcast_ref::<crate::policies::CacheAwarePolicy>()
    {
        cache_aware.init_workers(&decode_workers);
    }

    let prefill_workers = Arc::new(RwLock::new(prefill_workers));
    let decode_workers = Arc::new(RwLock::new(decode_workers));

    let check_interval = router_config.worker_startup_check_interval_secs;
    let prefill_health_checker =
        crate::core::start_health_checker(Arc::clone(&prefill_workers), check_interval);
    let decode_health_checker =
        crate::core::start_health_checker(Arc::clone(&decode_workers), check_interval);

    Ok(GrpcPDRouter {
        prefill_workers,
        decode_workers,
        prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)),
        decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)),
        prefill_policy,
        decode_policy,
        tokenizer,
        reasoning_parser_factory,
        tool_parser_registry,
        _prefill_health_checker: Some(prefill_health_checker),
        _decode_health_checker: Some(decode_health_checker),
        timeout_secs: router_config.worker_startup_timeout_secs,
        interval_secs: router_config.worker_startup_check_interval_secs,
        dp_aware: router_config.dp_aware,
        api_key: router_config.api_key.clone(),
        retry_config: router_config.effective_retry_config(),
        circuit_breaker_config: core_cb_config,
    })
}
}
// Debug output lists the worker counts and the scalar settings only.
impl std::fmt::Debug for GrpcPDRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatting must not panic: a poisoned lock still guards a readable
        // Vec, so recover the guard instead of unwrapping.
        let prefill_count = self
            .prefill_workers
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len();
        let decode_count = self
            .decode_workers
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len();

        f.debug_struct("GrpcPDRouter")
            .field("prefill_workers_count", &prefill_count)
            .field("decode_workers_count", &decode_count)
            .field("timeout_secs", &self.timeout_secs)
            .field("interval_secs", &self.interval_secs)
            .field("dp_aware", &self.dp_aware)
            .finish()
    }
}
#[async_trait]
impl RouterTrait for GrpcPDRouter {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn health(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_models(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_model_info(&self, _req: Request<Body>) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_generate(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::GenerateRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_chat(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ChatCompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_completion(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::CompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn flush_cache(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn get_worker_loads(&self) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
fn router_type(&self) -> &'static str {
"grpc_pd"
}
fn readiness(&self) -> Response {
(StatusCode::SERVICE_UNAVAILABLE).into_response()
}
}
// Worker management for the gRPC PD router. Adding and removing workers is not
// implemented yet; listing reports the workers registered at construction.
#[async_trait]
impl WorkerManagement for GrpcPDRouter {
    /// Not implemented: always returns an error.
    async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
        Err("Not implemented".to_string())
    }

    /// Not implemented: does nothing.
    fn remove_worker(&self, _worker_url: &str) {}

    /// URLs of all registered workers, prefill workers first, then decode
    /// workers. Previously this returned an empty list even though the router
    /// holds workers, unlike the plain gRPC router.
    fn get_worker_urls(&self) -> Vec<String> {
        // A poisoned lock still guards valid data; read it rather than panic.
        let prefill = self
            .prefill_workers
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let decode = self
            .decode_workers
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        prefill
            .iter()
            .chain(decode.iter())
            .map(|worker| worker.url().to_string())
            .collect()
    }
}

View File

@@ -0,0 +1,266 @@
// gRPC Router Implementation
use crate::config::types::RetryConfig;
use crate::core::{
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
};
use crate::grpc::SglangSchedulerClient;
use crate::metrics::RouterMetrics;
use crate::policies::LoadBalancingPolicy;
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterTrait, WorkerManagement};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry;
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tracing::{info, warn};
/// gRPC router implementation for SGLang.
///
/// Holds the worker list, a gRPC client map, the load-balancing policy, and
/// components and settings copied out of the application context by `new`.
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcRouter {
/// Worker connections
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
/// gRPC clients for each worker, keyed by worker URL.
/// NOTE(review): `new` moves every connected client into its worker before
/// storing this map, so it appears to start out empty — confirm.
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
/// Load balancing policy
policy: Arc<dyn LoadBalancingPolicy>,
/// Tokenizer for handling text encoding/decoding
tokenizer: Arc<dyn Tokenizer>,
/// Reasoning parser factory for structured reasoning outputs
reasoning_parser_factory: ParserFactory,
/// Tool parser registry for function/tool calls
tool_parser_registry: &'static ParserRegistry,
/// Health checker started over `workers` in `new`
_health_checker: Option<HealthChecker>,
// ---- Configuration (copied from `ctx.router_config` in `new`) ----
/// Taken from `worker_startup_timeout_secs`
timeout_secs: u64,
/// Taken from `worker_startup_check_interval_secs`
interval_secs: u64,
/// Taken from `dp_aware`
dp_aware: bool,
/// Taken from `api_key`
api_key: Option<String>,
/// Result of `effective_retry_config()`
retry_config: RetryConfig,
/// Core circuit-breaker settings converted from `effective_circuit_breaker_config()`
circuit_breaker_config: CircuitBreakerConfig,
}
impl GrpcRouter {
/// Create a new gRPC router.
///
/// Steps, in order: publish the configured worker count to the metrics, pull
/// the tokenizer / reasoning parser factory / tool parser registry out of `ctx`,
/// connect a gRPC client to every URL, wrap each connected client in a worker,
/// initialise a cache-aware policy if one is used, and start the health checker.
///
/// # Errors
/// Fails when `ctx` lacks a tokenizer, reasoning parser factory or tool parser
/// registry, or when no worker could be connected.
pub async fn new(
    worker_urls: Vec<String>,
    policy: Arc<dyn LoadBalancingPolicy>,
    ctx: &Arc<crate::server::AppContext>,
) -> Result<Self, String> {
    let router_config = &ctx.router_config;

    // The metric reflects the configured URLs, not the connections that succeed.
    RouterMetrics::set_active_workers(worker_urls.len());

    // Mandatory components from the application context.
    let tokenizer = ctx
        .tokenizer
        .clone()
        .ok_or_else(|| "gRPC router requires tokenizer".to_string())?;
    let reasoning_parser_factory = ctx
        .reasoning_parser_factory
        .clone()
        .ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?;
    let tool_parser_registry = ctx
        .tool_parser_registry
        .ok_or_else(|| "gRPC router requires tool parser registry".to_string())?;

    // Translate the config-level circuit breaker settings (seconds) into the
    // core representation (`Duration`s).
    let cb_settings = router_config.effective_circuit_breaker_config();
    let core_cb_config = CircuitBreakerConfig {
        failure_threshold: cb_settings.failure_threshold,
        success_threshold: cb_settings.success_threshold,
        timeout_duration: Duration::from_secs(cb_settings.timeout_duration_secs),
        window_duration: Duration::from_secs(cb_settings.window_duration_secs),
    };

    // Connect to every worker; a failed connection is logged and skipped.
    let mut grpc_clients = HashMap::new();
    for url in worker_urls.iter() {
        match SglangSchedulerClient::connect(url).await {
            Err(e) => {
                warn!("Failed to connect to gRPC worker at {}: {}", url, e);
            }
            Ok(client) => {
                grpc_clients.insert(url.clone(), client);
                info!("Connected to gRPC worker at {}", url);
            }
        }
    }
    if grpc_clients.is_empty() {
        return Err("Failed to connect to any gRPC workers".to_string());
    }

    // Every worker gets its own copy of the health-check settings.
    let health_settings = &router_config.health_check;
    let make_health_config = || HealthConfig {
        timeout_secs: health_settings.timeout_secs,
        check_interval_secs: health_settings.check_interval_secs,
        endpoint: health_settings.endpoint.clone(),
        failure_threshold: health_settings.failure_threshold,
        success_threshold: health_settings.success_threshold,
    };

    // Hand each connected client over to a worker. `remove` takes the client
    // out of the map, so URLs that failed to connect (or a repeated URL whose
    // client was already taken) are skipped with a warning.
    let mut workers: Vec<Box<dyn Worker>> = Vec::with_capacity(grpc_clients.len());
    for url in worker_urls.iter() {
        let client = match grpc_clients.remove(url) {
            Some(client) => client,
            None => {
                warn!("No gRPC client for worker {}, skipping", url);
                continue;
            }
        };
        let worker = BasicWorker::with_connection_mode(
            url.clone(),
            WorkerType::Regular,
            crate::core::ConnectionMode::Grpc { port: None },
        )
        .with_circuit_breaker_config(core_cb_config.clone())
        .with_health_config(make_health_config())
        .with_grpc_client(client);
        workers.push(Box::new(worker));
    }

    // A cache-aware policy needs to be told about the worker set up front.
    if let Some(cache_aware) = policy
        .as_any()
        .downcast_ref::<crate::policies::CacheAwarePolicy>()
    {
        cache_aware.init_workers(&workers);
    }

    let workers = Arc::new(RwLock::new(workers));
    let health_checker = crate::core::start_health_checker(
        Arc::clone(&workers),
        router_config.worker_startup_check_interval_secs,
    );

    Ok(GrpcRouter {
        workers,
        grpc_clients: Arc::new(RwLock::new(grpc_clients)),
        policy,
        tokenizer,
        reasoning_parser_factory,
        tool_parser_registry,
        _health_checker: Some(health_checker),
        timeout_secs: router_config.worker_startup_timeout_secs,
        interval_secs: router_config.worker_startup_check_interval_secs,
        dp_aware: router_config.dp_aware,
        api_key: router_config.api_key.clone(),
        retry_config: router_config.effective_retry_config(),
        circuit_breaker_config: core_cb_config,
    })
}
}
// Debug output lists the worker count and the scalar settings only.
impl std::fmt::Debug for GrpcRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatting must not panic: a poisoned lock still guards a readable
        // Vec, so recover the guard instead of unwrapping.
        let workers_count = self
            .workers
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len();

        f.debug_struct("GrpcRouter")
            .field("workers_count", &workers_count)
            .field("timeout_secs", &self.timeout_secs)
            .field("interval_secs", &self.interval_secs)
            .field("dp_aware", &self.dp_aware)
            .finish()
    }
}
// `RouterTrait` for the gRPC router. Every request handler is still a stub
// that answers `501 Not Implemented`; `readiness` always answers
// `503 Service Unavailable`.
#[async_trait]
impl RouterTrait for GrpcRouter {
    /// Returns `self` as `&dyn Any`, the type that supports downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Stub: 501.
    async fn health(&self, _req: Request<Body>) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn health_generate(&self, _req: Request<Body>) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn get_models(&self, _req: Request<Body>) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn route_generate(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::GenerateRequest,
    ) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn route_chat(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::ChatCompletionRequest,
    ) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &crate::protocols::spec::CompletionRequest,
    ) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn flush_cache(&self) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Stub: 501.
    async fn get_worker_loads(&self) -> Response {
        StatusCode::NOT_IMPLEMENTED.into_response()
    }

    /// Identifier of this router flavour.
    fn router_type(&self) -> &'static str {
        "grpc"
    }

    /// Always 503 while the router is unimplemented.
    fn readiness(&self) -> Response {
        StatusCode::SERVICE_UNAVAILABLE.into_response()
    }
}
// Worker management for the gRPC router. Adding and removing workers is not
// implemented yet; listing reports the workers built at construction.
#[async_trait]
impl WorkerManagement for GrpcRouter {
    /// Not implemented: always returns an error.
    async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
        Err(String::from("Not implemented"))
    }

    /// Not implemented: does nothing.
    fn remove_worker(&self, _worker_url: &str) {}

    /// URLs of all registered workers, in registration order.
    ///
    /// Panics if the worker lock is poisoned.
    fn get_worker_urls(&self) -> Vec<String> {
        let registered = self.workers.read().unwrap();
        let mut urls = Vec::with_capacity(registered.len());
        for worker in registered.iter() {
            urls.push(worker.url().to_string());
        }
        urls
    }
}

View File

@@ -0,0 +1,53 @@
use axum::body::Body;
use axum::extract::Request;
use axum::http::HeaderMap;
/// Copy the request's headers into owned `(name, value)` string pairs.
///
/// Used for forwarding headers to backend workers. A header whose value cannot
/// be rendered by `HeaderValue::to_str` is left out.
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
    let mut pairs = Vec::new();
    for (name, value) in req.headers() {
        if let Ok(text) = value.to_str() {
            pairs.push((name.to_string(), text.to_string()));
        }
    }
    pairs
}
/// Copy the headers of an upstream (reqwest) response into a fresh `HeaderMap`
/// for the client, leaving out everything `should_forward_header` rejects
/// (hop-by-hop headers and the like).
///
/// Repeated headers (for example several `Set-Cookie` lines) are all kept:
/// values are appended rather than inserted, because `insert` would replace
/// earlier values of the same name and forward only the last one.
pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap {
    let mut headers = HeaderMap::with_capacity(reqwest_headers.len());
    for (name, value) in reqwest_headers.iter() {
        // `HeaderName::as_str` is always lower case, so no extra lowercasing
        // (and no allocation) is needed before the check.
        if should_forward_header(name.as_str()) {
            // The name and value are already validated; clone them as they are.
            headers.append(name.clone(), value.clone());
        }
    }
    headers
}
/// Decide whether a backend response header may be forwarded to the client.
///
/// Returns `false` for hop-by-hop headers and for headers the proxy must own
/// (`content-encoding`, `host`); `true` for everything else. The comparison is
/// ASCII case-insensitive, so callers need not lowercase `name` first.
fn should_forward_header(name: &str) -> bool {
    // Headers that must NOT be forwarded.
    const BLOCKED: [&str; 11] = [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        // The actual header is `Trailer`; `trailers` is kept for backward
        // compatibility with the previous list.
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
        "content-encoding", // Let axum/hyper handle encoding
        "host",             // Should not forward the backend's host header
    ];
    !BLOCKED
        .iter()
        .any(|blocked| name.eq_ignore_ascii_case(blocked))
}

View File

@@ -0,0 +1,6 @@
//! HTTP router implementations
pub mod openai_router;
pub mod pd_router;
pub mod pd_types;
pub mod router;

View File

@@ -0,0 +1,379 @@
//! OpenAI router implementation (reqwest-based)
use crate::config::CircuitBreakerConfig;
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use futures_util::StreamExt;
use std::{
any::Any,
sync::atomic::{AtomicBool, Ordering},
};
/// Router for OpenAI backend.
///
/// Proxies requests to a single OpenAI-compatible upstream identified by
/// `base_url`, guarded by one circuit breaker.
#[derive(Debug)]
pub struct OpenAIRouter {
/// HTTP client for upstream OpenAI-compatible API
client: reqwest::Client,
/// Base URL for identification (no trailing slash; `new` trims it)
base_url: String,
/// Circuit breaker consulted before chat requests and by `readiness`
circuit_breaker: CircuitBreaker,
/// Health status, read by `readiness`; initialised to `true` in `new`
healthy: AtomicBool,
}
impl OpenAIRouter {
/// Create a new OpenAI router for the upstream at `base_url`.
///
/// Trailing slashes are trimmed from `base_url`. The HTTP client uses a
/// 300-second request timeout. When `circuit_breaker_config` is `None`, the
/// core circuit breaker defaults are used.
///
/// # Errors
/// Returns an error when the HTTP client cannot be built.
pub async fn new(
    base_url: String,
    circuit_breaker_config: Option<CircuitBreakerConfig>,
) -> Result<Self, String> {
    use std::time::Duration;

    let client = match reqwest::Client::builder()
        .timeout(Duration::from_secs(300))
        .build()
    {
        Ok(client) => client,
        Err(e) => return Err(format!("Failed to create HTTP client: {}", e)),
    };

    // Convert the config-level settings (seconds) into the core representation.
    let core_cb_config = match circuit_breaker_config {
        Some(cb) => CoreCircuitBreakerConfig {
            failure_threshold: cb.failure_threshold,
            success_threshold: cb.success_threshold,
            timeout_duration: Duration::from_secs(cb.timeout_duration_secs),
            window_duration: Duration::from_secs(cb.window_duration_secs),
        },
        None => Default::default(),
    };

    Ok(Self {
        client,
        base_url: base_url.trim_end_matches('/').to_string(),
        circuit_breaker: CircuitBreaker::with_config(core_cb_config),
        healthy: AtomicBool::new(true),
    })
}
}
// Worker management for the OpenAI router: the upstream is fixed, so workers
// can be neither added nor removed.
#[async_trait]
impl super::super::WorkerManagement for OpenAIRouter {
    /// Always fails: this router has a single fixed upstream.
    async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
        Err(String::from("Cannot add workers to OpenAI router"))
    }

    /// No-op: there is nothing to remove.
    fn remove_worker(&self, _worker_url: &str) {}

    /// The single upstream base URL.
    fn get_worker_urls(&self) -> Vec<String> {
        let upstream = self.base_url.clone();
        vec![upstream]
    }
}
#[async_trait]
impl super::super::RouterTrait for OpenAIRouter {
/// Returns `self` as `&dyn Any`, the type that supports downcasting.
fn as_any(&self) -> &dyn Any {
self
}
/// Probe the upstream with an unauthenticated `GET {base}/v1/models`
/// (2-second timeout).
///
/// Answers `200 OK` when the upstream replies with a success status or with
/// 401/403 (reachable but auth required); otherwise `503` with the upstream
/// status or the transport error in the body.
async fn health(&self, _req: Request<Body>) -> Response {
    let url = format!("{}/v1/models", self.base_url);
    let probe = self
        .client
        .get(&url)
        .timeout(std::time::Duration::from_secs(2))
        .send()
        .await;

    let code = match probe {
        Ok(resp) => resp.status(),
        Err(e) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                format!("Upstream error: {}", e),
            )
                .into_response();
        }
    };

    // Success and auth-required both prove that the endpoint is reachable.
    let reachable = code.is_success() || matches!(code.as_u16(), 401 | 403);
    if reachable {
        (StatusCode::OK, "OK").into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            format!("Upstream status: {}", code),
        )
            .into_response()
    }
}
/// For the OpenAI backend, `health_generate` is the same probe as `health`.
///
/// The parameter is named `req` (not `_req`) because it is used: it is
/// forwarded to `health`.
async fn health_generate(&self, req: Request<Body>) -> Response {
    self.health(req).await
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
let info = serde_json::json!({
"router_type": "openai",
"workers": 1,
"base_url": &self.base_url
});
(StatusCode::OK, info.to_string()).into_response()
}
/// Proxy `GET /v1/models` to the upstream, forwarding the caller's
/// `Authorization` header when present.
///
/// The upstream status, body and `Content-Type` are passed through. Answers
/// `502` when the upstream cannot be contacted and `500` when its body cannot
/// be read.
async fn get_models(&self, req: Request<Body>) -> Response {
    let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));

    // `HeaderMap` lookups are case-insensitive, so one lookup covers every
    // spelling of the header name.
    if let Some(auth) = req.headers().get(axum::http::header::AUTHORIZATION) {
        upstream = upstream.header(axum::http::header::AUTHORIZATION, auth);
    }

    let res = match upstream.send().await {
        Ok(res) => res,
        Err(e) => {
            return (
                StatusCode::BAD_GATEWAY,
                format!("Failed to contact upstream: {}", e),
            )
                .into_response();
        }
    };

    let status =
        StatusCode::from_u16(res.status().as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    // Capture the Content-Type before the body consumes the response.
    let content_type = res.headers().get(CONTENT_TYPE).cloned();

    match res.bytes().await {
        Ok(bytes) => {
            let mut response = Response::new(axum::body::Body::from(bytes));
            *response.status_mut() = status;
            if let Some(ct) = content_type {
                response.headers_mut().insert(CONTENT_TYPE, ct);
            }
            response
        }
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to read upstream response: {}", e),
        )
            .into_response(),
    }
}
/// Not supported without a model parameter: always answers `501`.
async fn get_model_info(&self, _req: Request<Body>) -> Response {
    let message = "get_model_info not implemented for OpenAI router";
    (StatusCode::NOT_IMPLEMENTED, message).into_response()
}
/// The generate endpoint is SGLang-specific and has no OpenAI counterpart:
/// always answers `501`.
async fn route_generate(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &GenerateRequest,
) -> Response {
    let message = "Generate endpoint not supported for OpenAI backend";
    (StatusCode::NOT_IMPLEMENTED, message).into_response()
}
/// Proxy a chat completion request to `{base}/v1/chat/completions`.
///
/// The request is serialised to JSON with the SGLang-only fields removed, the
/// caller's `Authorization` header is forwarded, and the upstream status is
/// passed through. With `stream = true` the upstream bytes are relayed as
/// `text/event-stream`; otherwise the whole body is buffered and returned with
/// the upstream `Content-Type`.
///
/// Circuit breaker: a request is rejected with `503` while the breaker denies
/// execution. A transport failure or unreadable body is recorded as a failure.
/// A response that is fully read (non-streaming) or whose headers arrive
/// (streaming) is recorded as a success; before this change the streaming path
/// recorded no outcome at all.
async fn route_chat(
    &self,
    headers: Option<&HeaderMap>,
    body: &ChatCompletionRequest,
) -> Response {
    // Request fields that are removed before the payload is sent upstream.
    const SGLANG_ONLY_FIELDS: &[&str] = &[
        "top_k",
        "min_p",
        "min_tokens",
        "regex",
        "ebnf",
        "stop_token_ids",
        "no_stop_trim",
        "ignore_eos",
        "continue_final_message",
        "skip_special_tokens",
        "lora_path",
        "session_params",
        "separate_reasoning",
        "stream_reasoning",
        "chat_template_kwargs",
        "return_hidden_states",
        "repetition_penalty",
    ];

    if !self.circuit_breaker.can_execute() {
        return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
    }

    // Serialize request body, removing SGLang-only fields
    let mut payload = match serde_json::to_value(body) {
        Ok(v) => v,
        Err(e) => {
            return (
                StatusCode::BAD_REQUEST,
                format!("Failed to serialize request: {}", e),
            )
                .into_response();
        }
    };
    if let Some(obj) = payload.as_object_mut() {
        for key in SGLANG_ONLY_FIELDS {
            obj.remove(*key);
        }
    }

    let url = format!("{}/v1/chat/completions", self.base_url);
    let mut req = self.client.post(&url).json(&payload);

    // Forward the Authorization header if provided. `HeaderMap` lookups are
    // case-insensitive, so a single lookup is sufficient.
    if let Some(auth) = headers.and_then(|h| h.get(axum::http::header::AUTHORIZATION)) {
        req = req.header(axum::http::header::AUTHORIZATION, auth);
    }

    // Accept SSE when stream=true
    if body.stream {
        req = req.header("Accept", "text/event-stream");
    }

    let resp = match req.send().await {
        Ok(r) => r,
        Err(e) => {
            self.circuit_breaker.record_failure();
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                format!("Failed to contact upstream: {}", e),
            )
                .into_response();
        }
    };

    let status = StatusCode::from_u16(resp.status().as_u16())
        .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

    if !body.stream {
        // Capture Content-Type before consuming response body
        let content_type = resp.headers().get(CONTENT_TYPE).cloned();
        match resp.bytes().await {
            Ok(bytes) => {
                self.circuit_breaker.record_success();
                let mut response = Response::new(axum::body::Body::from(bytes));
                *response.status_mut() = status;
                if let Some(ct) = content_type {
                    response.headers_mut().insert(CONTENT_TYPE, ct);
                }
                response
            }
            Err(e) => {
                self.circuit_breaker.record_failure();
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Failed to read response: {}", e),
                )
                    .into_response()
            }
        }
    } else {
        // The upstream answered, so report the outcome to the breaker here;
        // errors that occur mid-stream are only relayed to the client.
        self.circuit_breaker.record_success();

        // Stream SSE bytes to client.
        // NOTE(review): the channel is unbounded, so a slow client does not
        // slow down reading from the upstream — confirm this is acceptable.
        let stream = resp.bytes_stream();
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        tokio::spawn(async move {
            let mut s = stream;
            while let Some(chunk) = s.next().await {
                match chunk {
                    Ok(bytes) => {
                        // The receiver is gone once the client disconnects.
                        if tx.send(Ok(bytes)).is_err() {
                            break;
                        }
                    }
                    Err(e) => {
                        let _ = tx.send(Err(format!("Stream error: {}", e)));
                        break;
                    }
                }
            }
        });
        let mut response = Response::new(Body::from_stream(
            tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
        ));
        *response.status_mut() = status;
        response
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
        response
    }
}
/// Legacy completion endpoint.
///
/// The OpenAI router does not proxy this endpoint: both arguments are ignored
/// and the reply is always `501 Not Implemented` with a fixed text body.
async fn route_completion(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &CompletionRequest,
) -> Response {
    let status = StatusCode::NOT_IMPLEMENTED;
    let message = "Completion endpoint not implemented for OpenAI backend";
    (status, message).into_response()
}
/// Cache flushing is not supported by the OpenAI router.
///
/// Always replies with `501 Not Implemented` and a fixed text body.
async fn flush_cache(&self) -> Response {
    let unsupported = (
        StatusCode::NOT_IMPLEMENTED,
        "flush_cache not supported for OpenAI router",
    );
    unsupported.into_response()
}
/// Worker load reporting is not supported by the OpenAI router.
///
/// Always replies with `501 Not Implemented` and a fixed text body.
async fn get_worker_loads(&self) -> Response {
    let unsupported = (
        StatusCode::NOT_IMPLEMENTED,
        "get_worker_loads not supported for OpenAI router",
    );
    unsupported.into_response()
}
/// Returns the static type name identifying this router implementation.
fn router_type(&self) -> &'static str {
    const ROUTER_TYPE: &str = "openai";
    ROUTER_TYPE
}
fn readiness(&self) -> Response {
if self.healthy.load(Ordering::Acquire) && self.circuit_breaker.can_execute() {
(StatusCode::OK, "Ready").into_response()
} else {
(StatusCode::SERVICE_UNAVAILABLE, "Not ready").into_response()
}
}
/// Embeddings endpoint.
///
/// Not proxied by the OpenAI router: the headers and body are ignored and
/// the reply is always `501 Not Implemented` with a fixed text body.
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
    let status = StatusCode::NOT_IMPLEMENTED;
    let message = "Embeddings endpoint not implemented for OpenAI backend";
    (status, message).into_response()
}
/// Rerank endpoint.
///
/// Not proxied by the OpenAI router: the headers and body are ignored and
/// the reply is always `501 Not Implemented` with a fixed text body.
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
    let status = StatusCode::NOT_IMPLEMENTED;
    let message = "Rerank endpoint not implemented for OpenAI backend";
    (status, message).into_response()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,81 @@
/// Custom error type for PD router operations.
///
/// `Display` and `std::error::Error` are generated by `thiserror` from the
/// `#[error(...)]` message attached to each variant.
#[derive(Debug, thiserror::Error)]
pub enum PDRouterError {
/// A worker with the given URL already exists.
#[error("Worker already exists: {url}")]
WorkerAlreadyExists { url: String },
/// No worker with the given URL was found.
#[error("Worker not found: {url}")]
WorkerNotFound { url: String },
/// A lock could not be acquired; `operation` names what was being attempted.
#[error("Lock acquisition failed: {operation}")]
LockError { operation: String },
/// The health check for the worker at `url` failed.
#[error("Health check failed for worker: {url}")]
HealthCheckFailed { url: String },
/// The worker configuration was rejected; `reason` explains why.
#[error("Invalid worker configuration: {reason}")]
InvalidConfiguration { reason: String },
/// A network-level failure, described by `message`.
#[error("Network error: {message}")]
NetworkError { message: String },
/// Waiting for the worker at `url` timed out.
#[error("Timeout waiting for worker: {url}")]
Timeout { url: String },
}
// Helper functions for workers
/// Joins a worker base URL and an API path with exactly one `/` between them.
///
/// * `url` - base URL of the worker; any trailing `/` characters are ignored.
/// * `api_path` - endpoint path, with or without a leading `/`.
///
/// Returns the combined URL as an owned `String`.
pub fn api_path(url: &str, api_path: &str) -> String {
    // A base URL configured with a trailing slash used to yield `host//path`;
    // normalise both sides so the separator appears exactly once.
    let base = url.trim_end_matches('/');
    // Only a single leading slash is dropped, so the rest of the path is
    // forwarded untouched.
    let path = api_path.strip_prefix('/').unwrap_or(api_path);
    format!("{}/{}", base, path)
}
/// Extracts the host portion of a worker URL without external dependencies.
///
/// Leading `http://` / `https://` prefixes are removed, anything from the
/// first `/`, `?` or `#` onwards is discarded, and the port (if any) is
/// dropped. A bracketed IPv6 literal such as `[::1]:8080` is returned with
/// its brackets (`[::1]`) rather than being cut at its first `:`.
///
/// An input with no host part (for example an empty string) yields an empty
/// string.
pub fn get_hostname(url: &str) -> String {
    let rest = url
        .trim_start_matches("http://")
        .trim_start_matches("https://");

    // The authority ends where the path, query or fragment begins; without
    // this a URL such as `http://host/v1` would produce `host/v1`.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or(rest);

    // IPv6 literals contain ':' themselves, so they must be delimited by the
    // closing bracket instead of the first colon.
    if authority.starts_with('[') {
        if let Some(end) = authority.find(']') {
            return authority[..=end].to_string();
        }
    }

    authority
        .split(':')
        .next()
        .unwrap_or(authority)
        .to_string()
}
use serde::Serialize;
/// Optimized bootstrap wrapper for single requests.
///
/// Borrows the original request instead of cloning it. Because of
/// `#[serde(flatten)]`, the original request's fields are serialized at the
/// same level as the `bootstrap_*` fields rather than nested under an
/// `original` key.
#[derive(Serialize)]
pub struct RequestWithBootstrap<'a, T: Serialize> {
/// Borrowed original request; its fields are flattened into the output.
#[serde(flatten)]
pub original: &'a T,
/// Bootstrap host serialized with the request.
/// NOTE(review): presumably the hostname of the selected prefill worker - confirm against callers.
pub bootstrap_host: String,
/// Bootstrap port serialized with the request; `None` when no port is set.
pub bootstrap_port: Option<u16>,
/// Bootstrap room ID serialized with the request.
pub bootstrap_room: u64,
}
/// Optimized bootstrap wrapper for batch requests.
///
/// Same layout as the single-request wrapper, but every `bootstrap_*` field
/// is a `Vec`. Because of `#[serde(flatten)]`, the original request's fields
/// are serialized at the same level as the `bootstrap_*` fields.
/// NOTE(review): presumably each vector holds one entry per batch item, in
/// batch order - confirm against callers.
#[derive(Serialize)]
pub struct BatchRequestWithBootstrap<'a, T: Serialize> {
/// Borrowed original batch request; its fields are flattened into the output.
#[serde(flatten)]
pub original: &'a T,
/// Bootstrap hosts serialized with the request.
pub bootstrap_host: Vec<String>,
/// Bootstrap ports serialized with the request; an entry is `None` when no port is set.
pub bootstrap_port: Vec<Option<u16>>,
/// Bootstrap room IDs serialized with the request.
pub bootstrap_room: Vec<u64>,
}
/// Generates a random bootstrap room ID.
///
/// The result always lies in `[0, 2^63 - 1]`, the same range as Python's
/// `random.randint(0, 2**63 - 1)`.
pub fn generate_room_id() -> u64 {
    // All bits set except the top one, i.e. the largest non-negative i64.
    const ROOM_ID_MASK: u64 = u64::MAX >> 1;
    let raw: u64 = rand::random();
    raw & ROOM_ID_MASK
}
/// PD-specific routing policies.
///
/// The selection logic itself lives outside this type; the notes on the
/// variants below are inferred from their names and should be confirmed
/// against the PD router implementation.
#[derive(Debug, Clone, PartialEq)]
pub enum PDSelectionPolicy {
/// NOTE(review): presumably selects a worker at random - confirm.
Random,
/// NOTE(review): presumably the "power of two choices" strategy - confirm.
PowerOfTwo,
/// Cache-aware selection, parameterized by the three thresholds below.
CacheAware {
/// NOTE(review): presumably the cache-match ratio threshold - confirm units and range.
cache_threshold: f32,
/// NOTE(review): presumably an absolute load-imbalance threshold - confirm.
balance_abs_threshold: usize,
/// NOTE(review): presumably a relative load-imbalance threshold - confirm.
balance_rel_threshold: f32,
},
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,107 @@
//! Router implementations
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::fmt::Debug;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
pub mod factory;
pub mod grpc;
pub mod header_utils;
pub mod http;
pub use factory::RouterFactory;
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
pub use http::{openai_router, pd_router, pd_types, router};
/// Worker management trait for administrative operations
///
/// This trait is separate from RouterTrait to allow Send futures
/// for use in service discovery and other background tasks
///
/// Implementors must be `Send + Sync`; `RouterTrait` requires this trait as
/// a supertrait, so every router also exposes these operations.
#[async_trait]
pub trait WorkerManagement: Send + Sync {
/// Add a worker to the router
///
/// Both the success and the error side carry a `String`.
/// NOTE(review): presumably a human-readable status / error message - confirm with implementors.
async fn add_worker(&self, worker_url: &str) -> Result<String, String>;
/// Remove a worker from the router
///
/// Synchronous and returns nothing, so no outcome is reported to the caller.
fn remove_worker(&self, worker_url: &str);
/// Get all worker URLs
fn get_worker_urls(&self) -> Vec<String>;
}
/// Core trait for all router implementations
///
/// This trait provides a unified interface for routing requests,
/// regardless of whether it's a regular router or PD router.
///
/// Every routing method returns an axum `Response` directly rather than a
/// `Result`, so implementations encode failures in the response they return.
#[async_trait]
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
/// Get a reference to self as Any for downcasting
fn as_any(&self) -> &dyn std::any::Any;
/// Route a health check request
async fn health(&self, req: Request<Body>) -> Response;
/// Route a health generate request
async fn health_generate(&self, req: Request<Body>) -> Response;
/// Get server information
async fn get_server_info(&self, req: Request<Body>) -> Response;
/// Get available models
async fn get_models(&self, req: Request<Body>) -> Response;
/// Get model information
async fn get_model_info(&self, req: Request<Body>) -> Response;
/// Route a generate request
///
/// `headers` are the optional incoming request headers; `body` is the
/// already-parsed request.
async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
-> Response;
/// Route a chat completion request
///
/// `headers` are the optional incoming request headers; `body` is the
/// already-parsed request.
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
) -> Response;
/// Route a completion request
///
/// `headers` are the optional incoming request headers; `body` is the
/// already-parsed request.
async fn route_completion(
&self,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
) -> Response;
/// Route an embeddings request
///
/// Unlike the typed routes above, the payload is handed over as a raw,
/// owned `Body`.
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
/// Route a rerank request
///
/// Like `route_embeddings`, the payload is handed over as a raw, owned `Body`.
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
/// Flush cache on all workers
async fn flush_cache(&self) -> Response;
/// Get worker loads (for monitoring)
async fn get_worker_loads(&self) -> Response;
/// Get router type name
fn router_type(&self) -> &'static str;
/// Check if this is a PD router
///
/// The default implementation compares `router_type()` against `"pd"`.
fn is_pd_mode(&self) -> bool {
self.router_type() == "pd"
}
/// Server liveness check - is the server process running
///
/// The default implementation unconditionally answers `200 OK`.
fn liveness(&self) -> Response {
// Simple liveness check - if we can respond, we're alive
(StatusCode::OK, "OK").into_response()
}
/// Server readiness check - is the server ready to handle requests
///
/// No default is provided; each router defines its own readiness criteria.
fn readiness(&self) -> Response;
}