sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
195
sgl-router/src/routers/factory.rs
Normal file
195
sgl-router/src/routers/factory.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Factory for creating router instances
|
||||
|
||||
use super::{
|
||||
http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
|
||||
RouterTrait,
|
||||
};
|
||||
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
|
||||
use crate::policies::PolicyFactory;
|
||||
use crate::server::AppContext;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Factory for creating router instances based on configuration
|
||||
pub struct RouterFactory;
|
||||
|
||||
impl RouterFactory {
|
||||
/// Create a router instance from application context
|
||||
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Check if IGW mode is enabled
|
||||
if ctx.router_config.enable_igw {
|
||||
return Self::create_igw_router(ctx).await;
|
||||
}
|
||||
|
||||
// Check connection mode and route to appropriate implementation
|
||||
match ctx.router_config.connection_mode {
|
||||
ConnectionMode::Grpc => {
|
||||
// Route to gRPC implementation based on routing mode
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_grpc_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
Self::create_grpc_pd_router(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy.as_ref(),
|
||||
decode_policy.as_ref(),
|
||||
&ctx.router_config.policy,
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
RoutingMode::OpenAI { .. } => {
|
||||
Err("OpenAI mode requires HTTP connection_mode".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
ConnectionMode::Http => {
|
||||
// Route to HTTP implementation based on routing mode
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
|
||||
.await
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => {
|
||||
Self::create_pd_router(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy.as_ref(),
|
||||
decode_policy.as_ref(),
|
||||
&ctx.router_config.policy,
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
RoutingMode::OpenAI { worker_urls, .. } => {
|
||||
Self::create_openai_router(worker_urls.clone(), ctx).await
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a regular router with injected policy
|
||||
async fn create_regular_router(
|
||||
worker_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policy
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Create regular router with injected policy and context
|
||||
let router = Router::new(worker_urls.to_vec(), policy, ctx).await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a PD router with injected policy
|
||||
async fn create_pd_router(
|
||||
prefill_urls: &[(String, Option<u16>)],
|
||||
decode_urls: &[String],
|
||||
prefill_policy_config: Option<&PolicyConfig>,
|
||||
decode_policy_config: Option<&PolicyConfig>,
|
||||
main_policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
||||
let prefill_policy =
|
||||
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
|
||||
let decode_policy =
|
||||
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
||||
|
||||
// Create PD router with separate policies and context
|
||||
let router = PDRouter::new(
|
||||
prefill_urls.to_vec(),
|
||||
decode_urls.to_vec(),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a gRPC router with injected policy
|
||||
pub async fn create_grpc_router(
|
||||
worker_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
use super::grpc::router::GrpcRouter;
|
||||
|
||||
// Create policy
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Create gRPC router with context
|
||||
let router = GrpcRouter::new(worker_urls.to_vec(), policy, ctx).await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a gRPC PD router with tokenizer and worker configuration
|
||||
pub async fn create_grpc_pd_router(
|
||||
prefill_urls: &[(String, Option<u16>)],
|
||||
decode_urls: &[String],
|
||||
prefill_policy_config: Option<&PolicyConfig>,
|
||||
decode_policy_config: Option<&PolicyConfig>,
|
||||
main_policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
use super::grpc::pd_router::GrpcPDRouter;
|
||||
|
||||
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
||||
let prefill_policy =
|
||||
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
|
||||
let decode_policy =
|
||||
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
||||
|
||||
// Create gRPC PD router with context
|
||||
let router = GrpcPDRouter::new(
|
||||
prefill_urls.to_vec(),
|
||||
decode_urls.to_vec(),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create an OpenAI router
|
||||
async fn create_openai_router(
|
||||
worker_urls: Vec<String>,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Use the first worker URL as the OpenAI-compatible base
|
||||
let base_url = worker_urls
|
||||
.first()
|
||||
.cloned()
|
||||
.ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
|
||||
|
||||
let router =
|
||||
OpenAIRouter::new(base_url, Some(ctx.router_config.circuit_breaker.clone())).await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create an IGW router (placeholder for future implementation)
|
||||
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// For now, return an error indicating IGW is not yet implemented
|
||||
Err("IGW mode is not yet implemented".to_string())
|
||||
}
|
||||
}
|
||||
4
sgl-router/src/routers/grpc/mod.rs
Normal file
4
sgl-router/src/routers/grpc/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! gRPC router implementations
|
||||
|
||||
pub mod pd_router;
|
||||
pub mod router;
|
||||
328
sgl-router/src/routers/grpc/pd_router.rs
Normal file
328
sgl-router/src/routers/grpc/pd_router.rs
Normal file
@@ -0,0 +1,328 @@
|
||||
// PD (Prefill-Decode) gRPC Router Implementation
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||
};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// gRPC PD (Prefill-Decode) router implementation for SGLang
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
pub struct GrpcPDRouter {
|
||||
/// Prefill worker connections
|
||||
prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
/// Decode worker connections
|
||||
decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
/// gRPC clients for prefill workers
|
||||
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// gRPC clients for decode workers
|
||||
decode_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// Load balancing policy for prefill
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
/// Load balancing policy for decode
|
||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
/// Tokenizer for handling text encoding/decoding
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
/// Reasoning parser factory for structured reasoning outputs
|
||||
reasoning_parser_factory: ParserFactory,
|
||||
/// Tool parser registry for function/tool calls
|
||||
tool_parser_registry: &'static ParserRegistry,
|
||||
/// Worker health checkers
|
||||
_prefill_health_checker: Option<HealthChecker>,
|
||||
_decode_health_checker: Option<HealthChecker>,
|
||||
/// Configuration
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
}
|
||||
|
||||
impl GrpcPDRouter {
|
||||
/// Create a new gRPC PD router
|
||||
pub async fn new(
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
decode_urls: Vec<String>,
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
ctx: &Arc<crate::server::AppContext>,
|
||||
) -> Result<Self, String> {
|
||||
// Update metrics
|
||||
RouterMetrics::set_active_workers(prefill_urls.len() + decode_urls.len());
|
||||
|
||||
// Extract necessary components from context
|
||||
let tokenizer = ctx
|
||||
.tokenizer
|
||||
.as_ref()
|
||||
.ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
|
||||
.clone();
|
||||
let reasoning_parser_factory = ctx
|
||||
.reasoning_parser_factory
|
||||
.as_ref()
|
||||
.ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
|
||||
.clone();
|
||||
let tool_parser_registry = ctx
|
||||
.tool_parser_registry
|
||||
.ok_or_else(|| "gRPC PD router requires tool parser registry".to_string())?;
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||
let core_cb_config = CircuitBreakerConfig {
|
||||
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||
success_threshold: circuit_breaker_config.success_threshold,
|
||||
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create gRPC clients for prefill workers
|
||||
let mut prefill_grpc_clients = HashMap::new();
|
||||
for (url, _bootstrap_port) in &prefill_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
prefill_grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC prefill worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC prefill worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create gRPC clients for decode workers
|
||||
let mut decode_grpc_clients = HashMap::new();
|
||||
for url in &decode_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
decode_grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC decode worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC decode worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if prefill_grpc_clients.is_empty() && decode_grpc_clients.is_empty() {
|
||||
return Err("Failed to connect to any gRPC workers".to_string());
|
||||
}
|
||||
|
||||
// Create Prefill Worker trait objects with gRPC connection mode
|
||||
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||
.iter()
|
||||
.map(|(url, bootstrap_port)| {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Prefill {
|
||||
bootstrap_port: *bootstrap_port,
|
||||
},
|
||||
crate::core::ConnectionMode::Grpc {
|
||||
port: *bootstrap_port,
|
||||
},
|
||||
)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Create Decode Worker trait objects with gRPC connection mode
|
||||
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Decode,
|
||||
crate::core::ConnectionMode::Grpc { port: None },
|
||||
)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Initialize policies with workers if needed
|
||||
if let Some(cache_aware) = prefill_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&prefill_workers);
|
||||
}
|
||||
|
||||
if let Some(cache_aware) = decode_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&decode_workers);
|
||||
}
|
||||
|
||||
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
|
||||
let decode_workers = Arc::new(RwLock::new(decode_workers));
|
||||
|
||||
let prefill_health_checker = crate::core::start_health_checker(
|
||||
Arc::clone(&prefill_workers),
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
);
|
||||
let decode_health_checker = crate::core::start_health_checker(
|
||||
Arc::clone(&decode_workers),
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
);
|
||||
|
||||
Ok(GrpcPDRouter {
|
||||
prefill_workers,
|
||||
decode_workers,
|
||||
prefill_grpc_clients: Arc::new(RwLock::new(prefill_grpc_clients)),
|
||||
decode_grpc_clients: Arc::new(RwLock::new(decode_grpc_clients)),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
tokenizer,
|
||||
reasoning_parser_factory,
|
||||
tool_parser_registry,
|
||||
_prefill_health_checker: Some(prefill_health_checker),
|
||||
_decode_health_checker: Some(decode_health_checker),
|
||||
timeout_secs: ctx.router_config.worker_startup_timeout_secs,
|
||||
interval_secs: ctx.router_config.worker_startup_check_interval_secs,
|
||||
dp_aware: ctx.router_config.dp_aware,
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
circuit_breaker_config: core_cb_config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GrpcPDRouter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GrpcPDRouter")
|
||||
.field(
|
||||
"prefill_workers_count",
|
||||
&self.prefill_workers.read().unwrap().len(),
|
||||
)
|
||||
.field(
|
||||
"decode_workers_count",
|
||||
&self.decode_workers.read().unwrap().len(),
|
||||
)
|
||||
.field("timeout_secs", &self.timeout_secs)
|
||||
.field("interval_secs", &self.interval_secs)
|
||||
.field("dp_aware", &self.dp_aware)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouterTrait for GrpcPDRouter {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::GenerateRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ChatCompletionRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::CompletionRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn flush_cache(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"grpc_pd"
|
||||
}
|
||||
|
||||
fn readiness(&self) -> Response {
|
||||
(StatusCode::SERVICE_UNAVAILABLE).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for GrpcPDRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
Err("Not implemented".to_string())
|
||||
}
|
||||
|
||||
fn remove_worker(&self, _worker_url: &str) {}
|
||||
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
266
sgl-router/src/routers/grpc/router.rs
Normal file
266
sgl-router/src/routers/grpc/router.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
// gRPC Router Implementation
|
||||
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig, Worker, WorkerType,
|
||||
};
|
||||
use crate::grpc::SglangSchedulerClient;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use crate::tokenizer::traits::Tokenizer;
|
||||
use crate::tool_parser::ParserRegistry;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// gRPC router implementation for SGLang
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
pub struct GrpcRouter {
|
||||
/// Worker connections
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
/// gRPC clients for each worker
|
||||
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// Load balancing policy
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
/// Tokenizer for handling text encoding/decoding
|
||||
tokenizer: Arc<dyn Tokenizer>,
|
||||
/// Reasoning parser factory for structured reasoning outputs
|
||||
reasoning_parser_factory: ParserFactory,
|
||||
/// Tool parser registry for function/tool calls
|
||||
tool_parser_registry: &'static ParserRegistry,
|
||||
/// Worker health checker
|
||||
_health_checker: Option<HealthChecker>,
|
||||
/// Configuration
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
retry_config: RetryConfig,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
}
|
||||
|
||||
impl GrpcRouter {
|
||||
/// Create a new gRPC router
|
||||
pub async fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
ctx: &Arc<crate::server::AppContext>,
|
||||
) -> Result<Self, String> {
|
||||
// Update metrics
|
||||
RouterMetrics::set_active_workers(worker_urls.len());
|
||||
|
||||
// Extract necessary components from context
|
||||
let tokenizer = ctx
|
||||
.tokenizer
|
||||
.as_ref()
|
||||
.ok_or_else(|| "gRPC router requires tokenizer".to_string())?
|
||||
.clone();
|
||||
let reasoning_parser_factory = ctx
|
||||
.reasoning_parser_factory
|
||||
.as_ref()
|
||||
.ok_or_else(|| "gRPC router requires reasoning parser factory".to_string())?
|
||||
.clone();
|
||||
let tool_parser_registry = ctx
|
||||
.tool_parser_registry
|
||||
.ok_or_else(|| "gRPC router requires tool parser registry".to_string())?;
|
||||
|
||||
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
|
||||
let circuit_breaker_config = ctx.router_config.effective_circuit_breaker_config();
|
||||
let core_cb_config = CircuitBreakerConfig {
|
||||
failure_threshold: circuit_breaker_config.failure_threshold,
|
||||
success_threshold: circuit_breaker_config.success_threshold,
|
||||
timeout_duration: Duration::from_secs(circuit_breaker_config.timeout_duration_secs),
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create gRPC clients for each worker
|
||||
let mut grpc_clients = HashMap::new();
|
||||
for url in &worker_urls {
|
||||
match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => {
|
||||
grpc_clients.insert(url.clone(), client);
|
||||
info!("Connected to gRPC worker at {}", url);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to connect to gRPC worker at {}: {}", url, e);
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if grpc_clients.is_empty() {
|
||||
return Err("Failed to connect to any gRPC workers".to_string());
|
||||
}
|
||||
|
||||
// Create Worker trait objects with gRPC connection mode
|
||||
let mut workers: Vec<Box<dyn Worker>> = Vec::new();
|
||||
|
||||
// Move clients from the HashMap to the workers
|
||||
for url in &worker_urls {
|
||||
if let Some(client) = grpc_clients.remove(url) {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
url.clone(),
|
||||
WorkerType::Regular,
|
||||
crate::core::ConnectionMode::Grpc { port: None },
|
||||
)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
})
|
||||
.with_grpc_client(client);
|
||||
|
||||
workers.push(Box::new(worker) as Box<dyn Worker>);
|
||||
} else {
|
||||
warn!("No gRPC client for worker {}, skipping", url);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize policy with workers if needed
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&workers);
|
||||
}
|
||||
|
||||
let workers = Arc::new(RwLock::new(workers));
|
||||
let health_checker = crate::core::start_health_checker(
|
||||
Arc::clone(&workers),
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
);
|
||||
|
||||
Ok(GrpcRouter {
|
||||
workers,
|
||||
grpc_clients: Arc::new(RwLock::new(grpc_clients)),
|
||||
policy,
|
||||
tokenizer,
|
||||
reasoning_parser_factory,
|
||||
tool_parser_registry,
|
||||
_health_checker: Some(health_checker),
|
||||
timeout_secs: ctx.router_config.worker_startup_timeout_secs,
|
||||
interval_secs: ctx.router_config.worker_startup_check_interval_secs,
|
||||
dp_aware: ctx.router_config.dp_aware,
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
circuit_breaker_config: core_cb_config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for GrpcRouter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("GrpcRouter")
|
||||
.field("workers_count", &self.workers.read().unwrap().len())
|
||||
.field("timeout_secs", &self.timeout_secs)
|
||||
.field("interval_secs", &self.interval_secs)
|
||||
.field("dp_aware", &self.dp_aware)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouterTrait for GrpcRouter {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::GenerateRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ChatCompletionRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::CompletionRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn flush_cache(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"grpc"
|
||||
}
|
||||
|
||||
fn readiness(&self) -> Response {
|
||||
(StatusCode::SERVICE_UNAVAILABLE).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl WorkerManagement for GrpcRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
Err("Not implemented".to_string())
|
||||
}
|
||||
|
||||
fn remove_worker(&self, _worker_url: &str) {}
|
||||
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
self.workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
53
sgl-router/src/routers/header_utils.rs
Normal file
53
sgl-router/src/routers/header_utils.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use axum::body::Body;
|
||||
use axum::extract::Request;
|
||||
use axum::http::HeaderMap;
|
||||
|
||||
/// Copy request headers to a Vec of name-value string pairs
|
||||
/// Used for forwarding headers to backend workers
|
||||
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
|
||||
req.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
// Convert header value to string, skipping non-UTF8 headers
|
||||
value
|
||||
.to_str()
|
||||
.ok()
|
||||
.map(|v| (name.to_string(), v.to_string()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convert headers from reqwest Response to axum HeaderMap
|
||||
/// Filters out hop-by-hop headers that shouldn't be forwarded
|
||||
pub fn preserve_response_headers(reqwest_headers: &HeaderMap) -> HeaderMap {
|
||||
let mut headers = HeaderMap::new();
|
||||
|
||||
for (name, value) in reqwest_headers.iter() {
|
||||
// Skip hop-by-hop headers that shouldn't be forwarded
|
||||
let name_str = name.as_str().to_lowercase();
|
||||
if should_forward_header(&name_str) {
|
||||
// The original name and value are already valid, so we can just clone them
|
||||
headers.insert(name.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
headers
|
||||
}
|
||||
|
||||
/// Determine if a header should be forwarded from backend to client
|
||||
fn should_forward_header(name: &str) -> bool {
|
||||
// List of headers that should NOT be forwarded (hop-by-hop headers)
|
||||
!matches!(
|
||||
name,
|
||||
"connection" |
|
||||
"keep-alive" |
|
||||
"proxy-authenticate" |
|
||||
"proxy-authorization" |
|
||||
"te" |
|
||||
"trailers" |
|
||||
"transfer-encoding" |
|
||||
"upgrade" |
|
||||
"content-encoding" | // Let axum/hyper handle encoding
|
||||
"host" // Should not forward the backend's host header
|
||||
)
|
||||
}
|
||||
6
sgl-router/src/routers/http/mod.rs
Normal file
6
sgl-router/src/routers/http/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
//! HTTP router implementations
|
||||
|
||||
pub mod openai_router;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
pub mod router;
|
||||
379
sgl-router/src/routers/http/openai_router.rs
Normal file
379
sgl-router/src/routers/http/openai_router.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
//! OpenAI router implementation (reqwest-based)
|
||||
|
||||
use crate::config::CircuitBreakerConfig;
|
||||
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
|
||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use futures_util::StreamExt;
|
||||
use std::{
|
||||
any::Any,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
|
||||
/// Router for OpenAI backend
|
||||
#[derive(Debug)]
|
||||
pub struct OpenAIRouter {
|
||||
/// HTTP client for upstream OpenAI-compatible API
|
||||
client: reqwest::Client,
|
||||
/// Base URL for identification (no trailing slash)
|
||||
base_url: String,
|
||||
/// Circuit breaker
|
||||
circuit_breaker: CircuitBreaker,
|
||||
/// Health status
|
||||
healthy: AtomicBool,
|
||||
}
|
||||
|
||||
impl OpenAIRouter {
|
||||
/// Create a new OpenAI router
|
||||
pub async fn new(
|
||||
base_url: String,
|
||||
circuit_breaker_config: Option<CircuitBreakerConfig>,
|
||||
) -> Result<Self, String> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
let base_url = base_url.trim_end_matches('/').to_string();
|
||||
|
||||
// Convert circuit breaker config
|
||||
let core_cb_config = circuit_breaker_config
|
||||
.map(|cb| CoreCircuitBreakerConfig {
|
||||
failure_threshold: cb.failure_threshold,
|
||||
success_threshold: cb.success_threshold,
|
||||
timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
|
||||
window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
base_url,
|
||||
circuit_breaker,
|
||||
healthy: AtomicBool::new(true),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::super::WorkerManagement for OpenAIRouter {
|
||||
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||
Err("Cannot add workers to OpenAI router".to_string())
|
||||
}
|
||||
|
||||
fn remove_worker(&self, _worker_url: &str) {
|
||||
// No-op for OpenAI router
|
||||
}
|
||||
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
vec![self.base_url.clone()]
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl super::super::RouterTrait for OpenAIRouter {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
// Simple upstream probe: GET {base}/v1/models without auth
|
||||
let url = format!("{}/v1/models", self.base_url);
|
||||
match self
|
||||
.client
|
||||
.get(&url)
|
||||
.timeout(std::time::Duration::from_secs(2))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let code = resp.status();
|
||||
// Treat success and auth-required as healthy (endpoint reachable)
|
||||
if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
|
||||
(StatusCode::OK, "OK").into_response()
|
||||
} else {
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
format!("Upstream status: {}", code),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
format!("Upstream error: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
// For OpenAI, health_generate is the same as health
|
||||
self.health(_req).await
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
let info = serde_json::json!({
|
||||
"router_type": "openai",
|
||||
"workers": 1,
|
||||
"base_url": &self.base_url
|
||||
});
|
||||
(StatusCode::OK, info.to_string()).into_response()
|
||||
}
|
||||
|
||||
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||
// Proxy to upstream /v1/models; forward Authorization header if provided
|
||||
let headers = req.headers();
|
||||
|
||||
let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
|
||||
|
||||
if let Some(auth) = headers
|
||||
.get("authorization")
|
||||
.or_else(|| headers.get("Authorization"))
|
||||
{
|
||||
upstream = upstream.header("Authorization", auth);
|
||||
}
|
||||
|
||||
match upstream.send().await {
|
||||
Ok(res) => {
|
||||
let status = StatusCode::from_u16(res.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
let content_type = res.headers().get(CONTENT_TYPE).cloned();
|
||||
match res.bytes().await {
|
||||
Ok(body) => {
|
||||
let mut response = Response::new(axum::body::Body::from(body));
|
||||
*response.status_mut() = status;
|
||||
if let Some(ct) = content_type {
|
||||
response.headers_mut().insert(CONTENT_TYPE, ct);
|
||||
}
|
||||
response
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read upstream response: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Failed to contact upstream: {}", e),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||
// Not directly supported without model param; return 501
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"get_model_info not implemented for OpenAI router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &GenerateRequest,
|
||||
) -> Response {
|
||||
// Generate endpoint is SGLang-specific, not supported for OpenAI backend
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Generate endpoint not supported for OpenAI backend",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
) -> Response {
|
||||
if !self.circuit_breaker.can_execute() {
|
||||
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
|
||||
}
|
||||
|
||||
// Serialize request body, removing SGLang-only fields
|
||||
let mut payload = match serde_json::to_value(body) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to serialize request: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
if let Some(obj) = payload.as_object_mut() {
|
||||
for key in [
|
||||
"top_k",
|
||||
"min_p",
|
||||
"min_tokens",
|
||||
"regex",
|
||||
"ebnf",
|
||||
"stop_token_ids",
|
||||
"no_stop_trim",
|
||||
"ignore_eos",
|
||||
"continue_final_message",
|
||||
"skip_special_tokens",
|
||||
"lora_path",
|
||||
"session_params",
|
||||
"separate_reasoning",
|
||||
"stream_reasoning",
|
||||
"chat_template_kwargs",
|
||||
"return_hidden_states",
|
||||
"repetition_penalty",
|
||||
] {
|
||||
obj.remove(key);
|
||||
}
|
||||
}
|
||||
|
||||
let url = format!("{}/v1/chat/completions", self.base_url);
|
||||
let mut req = self.client.post(&url).json(&payload);
|
||||
|
||||
// Forward Authorization header if provided
|
||||
if let Some(h) = headers {
|
||||
if let Some(auth) = h.get("authorization").or_else(|| h.get("Authorization")) {
|
||||
req = req.header("Authorization", auth);
|
||||
}
|
||||
}
|
||||
|
||||
// Accept SSE when stream=true
|
||||
if body.stream {
|
||||
req = req.header("Accept", "text/event-stream");
|
||||
}
|
||||
|
||||
let resp = match req.send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
self.circuit_breaker.record_failure();
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
format!("Failed to contact upstream: {}", e),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let status = StatusCode::from_u16(resp.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
if !body.stream {
|
||||
// Capture Content-Type before consuming response body
|
||||
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
|
||||
match resp.bytes().await {
|
||||
Ok(body) => {
|
||||
self.circuit_breaker.record_success();
|
||||
let mut response = Response::new(axum::body::Body::from(body));
|
||||
*response.status_mut() = status;
|
||||
if let Some(ct) = content_type {
|
||||
response.headers_mut().insert(CONTENT_TYPE, ct);
|
||||
}
|
||||
response
|
||||
}
|
||||
Err(e) => {
|
||||
self.circuit_breaker.record_failure();
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to read response: {}", e),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Stream SSE bytes to client
|
||||
let stream = resp.bytes_stream();
|
||||
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||
tokio::spawn(async move {
|
||||
let mut s = stream;
|
||||
while let Some(chunk) = s.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
if tx.send(Ok(bytes)).is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = tx.send(Err(format!("Stream error: {}", e)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
let mut response = Response::new(Body::from_stream(
|
||||
tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
|
||||
));
|
||||
*response.status_mut() = status;
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &CompletionRequest,
|
||||
) -> Response {
|
||||
// Completion endpoint not implemented for OpenAI backend
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Completion endpoint not implemented for OpenAI backend",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn flush_cache(&self) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"flush_cache not supported for OpenAI router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"get_worker_loads not supported for OpenAI router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"openai"
|
||||
}
|
||||
|
||||
fn readiness(&self) -> Response {
|
||||
if self.healthy.load(Ordering::Acquire) && self.circuit_breaker.can_execute() {
|
||||
(StatusCode::OK, "Ready").into_response()
|
||||
} else {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "Not ready").into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Embeddings endpoint not implemented for OpenAI backend",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Rerank endpoint not implemented for OpenAI backend",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
2500
sgl-router/src/routers/http/pd_router.rs
Normal file
2500
sgl-router/src/routers/http/pd_router.rs
Normal file
File diff suppressed because it is too large
Load Diff
81
sgl-router/src/routers/http/pd_types.rs
Normal file
81
sgl-router/src/routers/http/pd_types.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
// Custom error type for PD router operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PDRouterError {
|
||||
#[error("Worker already exists: {url}")]
|
||||
WorkerAlreadyExists { url: String },
|
||||
|
||||
#[error("Worker not found: {url}")]
|
||||
WorkerNotFound { url: String },
|
||||
|
||||
#[error("Lock acquisition failed: {operation}")]
|
||||
LockError { operation: String },
|
||||
|
||||
#[error("Health check failed for worker: {url}")]
|
||||
HealthCheckFailed { url: String },
|
||||
|
||||
#[error("Invalid worker configuration: {reason}")]
|
||||
InvalidConfiguration { reason: String },
|
||||
|
||||
#[error("Network error: {message}")]
|
||||
NetworkError { message: String },
|
||||
|
||||
#[error("Timeout waiting for worker: {url}")]
|
||||
Timeout { url: String },
|
||||
}
|
||||
|
||||
// Helper functions for workers
|
||||
pub fn api_path(url: &str, api_path: &str) -> String {
|
||||
if api_path.starts_with("/") {
|
||||
format!("{}{}", url, api_path)
|
||||
} else {
|
||||
format!("{}/{}", url, api_path)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_hostname(url: &str) -> String {
|
||||
// Simple hostname extraction without external dependencies
|
||||
let url = url
|
||||
.trim_start_matches("http://")
|
||||
.trim_start_matches("https://");
|
||||
url.split(':').next().unwrap_or("localhost").to_string()
|
||||
}
|
||||
|
||||
use serde::Serialize;
|
||||
|
||||
// Optimized bootstrap wrapper for single requests
|
||||
#[derive(Serialize)]
|
||||
pub struct RequestWithBootstrap<'a, T: Serialize> {
|
||||
#[serde(flatten)]
|
||||
pub original: &'a T,
|
||||
pub bootstrap_host: String,
|
||||
pub bootstrap_port: Option<u16>,
|
||||
pub bootstrap_room: u64,
|
||||
}
|
||||
|
||||
// Optimized bootstrap wrapper for batch requests
|
||||
#[derive(Serialize)]
|
||||
pub struct BatchRequestWithBootstrap<'a, T: Serialize> {
|
||||
#[serde(flatten)]
|
||||
pub original: &'a T,
|
||||
pub bootstrap_host: Vec<String>,
|
||||
pub bootstrap_port: Vec<Option<u16>>,
|
||||
pub bootstrap_room: Vec<u64>,
|
||||
}
|
||||
|
||||
// Helper to generate bootstrap room ID
|
||||
pub fn generate_room_id() -> u64 {
|
||||
// Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1)
|
||||
rand::random::<u64>() & (i64::MAX as u64)
|
||||
}
|
||||
|
||||
// PD-specific routing policies
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum PDSelectionPolicy {
|
||||
Random,
|
||||
PowerOfTwo,
|
||||
CacheAware {
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
},
|
||||
}
|
||||
1387
sgl-router/src/routers/http/router.rs
Normal file
1387
sgl-router/src/routers/http/router.rs
Normal file
File diff suppressed because it is too large
Load Diff
107
sgl-router/src/routers/mod.rs
Normal file
107
sgl-router/src/routers/mod.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
//! Router implementations
|
||||
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
|
||||
pub mod factory;
|
||||
pub mod grpc;
|
||||
pub mod header_utils;
|
||||
pub mod http;
|
||||
|
||||
pub use factory::RouterFactory;
|
||||
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
|
||||
pub use http::{openai_router, pd_router, pd_types, router};
|
||||
|
||||
/// Worker management trait for administrative operations
|
||||
///
|
||||
/// This trait is separate from RouterTrait to allow Send futures
|
||||
/// for use in service discovery and other background tasks
|
||||
#[async_trait]
|
||||
pub trait WorkerManagement: Send + Sync {
|
||||
/// Add a worker to the router
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String>;
|
||||
|
||||
/// Remove a worker from the router
|
||||
fn remove_worker(&self, worker_url: &str);
|
||||
|
||||
/// Get all worker URLs
|
||||
fn get_worker_urls(&self) -> Vec<String>;
|
||||
}
|
||||
|
||||
/// Core trait for all router implementations
|
||||
///
|
||||
/// This trait provides a unified interface for routing requests,
|
||||
/// regardless of whether it's a regular router or PD router.
|
||||
#[async_trait]
|
||||
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
/// Get a reference to self as Any for downcasting
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
|
||||
/// Route a health check request
|
||||
async fn health(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Route a health generate request
|
||||
async fn health_generate(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Get server information
|
||||
async fn get_server_info(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Get available models
|
||||
async fn get_models(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Get model information
|
||||
async fn get_model_info(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Route a generate request
|
||||
async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
|
||||
-> Response;
|
||||
|
||||
/// Route a chat completion request
|
||||
async fn route_chat(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
) -> Response;
|
||||
|
||||
/// Route a completion request
|
||||
async fn route_completion(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
) -> Response;
|
||||
|
||||
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||
|
||||
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||
|
||||
/// Flush cache on all workers
|
||||
async fn flush_cache(&self) -> Response;
|
||||
|
||||
/// Get worker loads (for monitoring)
|
||||
async fn get_worker_loads(&self) -> Response;
|
||||
|
||||
/// Get router type name
|
||||
fn router_type(&self) -> &'static str;
|
||||
|
||||
/// Check if this is a PD router
|
||||
fn is_pd_mode(&self) -> bool {
|
||||
self.router_type() == "pd"
|
||||
}
|
||||
|
||||
/// Server liveness check - is the server process running
|
||||
fn liveness(&self) -> Response {
|
||||
// Simple liveness check - if we can respond, we're alive
|
||||
(StatusCode::OK, "OK").into_response()
|
||||
}
|
||||
|
||||
/// Server readiness check - is the server ready to handle requests
|
||||
fn readiness(&self) -> Response;
|
||||
}
|
||||
Reference in New Issue
Block a user