[router] grpc router bootstraps (#9759)
This commit is contained in:
@@ -7,7 +7,9 @@ use sglang_router_rs::protocols::spec::{
|
|||||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||||
SamplingParams, StringOrArray, UserMessageContent,
|
SamplingParams, StringOrArray, UserMessageContent,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
|
use sglang_router_rs::routers::http::pd_types::{
|
||||||
|
generate_room_id, get_hostname, RequestWithBootstrap,
|
||||||
|
};
|
||||||
|
|
||||||
fn create_test_worker() -> BasicWorker {
|
fn create_test_worker() -> BasicWorker {
|
||||||
BasicWorker::new(
|
BasicWorker::new(
|
||||||
|
|||||||
@@ -19,6 +19,6 @@ pub use circuit_breaker::{
|
|||||||
pub use error::{WorkerError, WorkerResult};
|
pub use error::{WorkerError, WorkerResult};
|
||||||
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
|
pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor};
|
||||||
pub use worker::{
|
pub use worker::{
|
||||||
start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, HealthConfig, Worker,
|
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
|
||||||
WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
|
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
|||||||
/// Get the worker's type (Regular, Prefill, or Decode)
|
/// Get the worker's type (Regular, Prefill, or Decode)
|
||||||
fn worker_type(&self) -> WorkerType;
|
fn worker_type(&self) -> WorkerType;
|
||||||
|
|
||||||
|
/// Get the worker's connection mode (HTTP or gRPC)
|
||||||
|
fn connection_mode(&self) -> ConnectionMode;
|
||||||
|
|
||||||
/// Check if the worker is currently healthy
|
/// Check if the worker is currently healthy
|
||||||
fn is_healthy(&self) -> bool;
|
fn is_healthy(&self) -> bool;
|
||||||
|
|
||||||
@@ -152,6 +155,30 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// How the router communicates with a worker.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConnectionMode {
    /// HTTP/REST connection
    Http,
    /// gRPC connection
    Grpc {
        /// Optional port for gRPC endpoint (if different from URL)
        port: Option<u16>,
    },
}

impl fmt::Display for ConnectionMode {
    /// Renders `HTTP`, `gRPC`, or `gRPC(port:N)` when an explicit port is set.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Http => f.write_str("HTTP"),
            Self::Grpc { port: None } => f.write_str("gRPC"),
            Self::Grpc { port: Some(explicit) } => write!(f, "gRPC(port:{})", explicit),
        }
    }
}
|
||||||
|
|
||||||
/// Worker type classification
|
/// Worker type classification
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
pub enum WorkerType {
|
pub enum WorkerType {
|
||||||
@@ -213,6 +240,8 @@ pub struct WorkerMetadata {
|
|||||||
pub url: String,
|
pub url: String,
|
||||||
/// Worker type
|
/// Worker type
|
||||||
pub worker_type: WorkerType,
|
pub worker_type: WorkerType,
|
||||||
|
/// Connection mode
|
||||||
|
pub connection_mode: ConnectionMode,
|
||||||
/// Additional labels/tags
|
/// Additional labels/tags
|
||||||
pub labels: std::collections::HashMap<String, String>,
|
pub labels: std::collections::HashMap<String, String>,
|
||||||
/// Health check configuration
|
/// Health check configuration
|
||||||
@@ -233,9 +262,18 @@ pub struct BasicWorker {
|
|||||||
|
|
||||||
impl BasicWorker {
|
impl BasicWorker {
|
||||||
pub fn new(url: String, worker_type: WorkerType) -> Self {
|
pub fn new(url: String, worker_type: WorkerType) -> Self {
|
||||||
|
Self::with_connection_mode(url, worker_type, ConnectionMode::Http)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_connection_mode(
|
||||||
|
url: String,
|
||||||
|
worker_type: WorkerType,
|
||||||
|
connection_mode: ConnectionMode,
|
||||||
|
) -> Self {
|
||||||
let metadata = WorkerMetadata {
|
let metadata = WorkerMetadata {
|
||||||
url: url.clone(),
|
url: url.clone(),
|
||||||
worker_type,
|
worker_type,
|
||||||
|
connection_mode,
|
||||||
labels: std::collections::HashMap::new(),
|
labels: std::collections::HashMap::new(),
|
||||||
health_config: HealthConfig::default(),
|
health_config: HealthConfig::default(),
|
||||||
};
|
};
|
||||||
@@ -298,6 +336,10 @@ impl Worker for BasicWorker {
|
|||||||
self.metadata.worker_type.clone()
|
self.metadata.worker_type.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return a copy of the connection mode stored in this worker's metadata.
fn connection_mode(&self) -> ConnectionMode {
    let stored_mode = &self.metadata.connection_mode;
    stored_mode.clone()
}
|
||||||
|
|
||||||
fn is_healthy(&self) -> bool {
|
fn is_healthy(&self) -> bool {
|
||||||
self.healthy.load(Ordering::Acquire)
|
self.healthy.load(Ordering::Acquire)
|
||||||
}
|
}
|
||||||
@@ -434,6 +476,10 @@ impl Worker for DPAwareWorker {
|
|||||||
self.base_worker.worker_type()
|
self.base_worker.worker_type()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the connection mode reported by the wrapped base worker.
fn connection_mode(&self) -> ConnectionMode {
    let base = &self.base_worker;
    base.connection_mode()
}
|
||||||
|
|
||||||
fn is_healthy(&self) -> bool {
|
fn is_healthy(&self) -> bool {
|
||||||
self.base_worker.is_healthy()
|
self.base_worker.is_healthy()
|
||||||
}
|
}
|
||||||
@@ -603,6 +649,28 @@ impl WorkerFactory {
|
|||||||
(regular_workers, prefill_workers, decode_workers)
|
(regular_workers, prefill_workers, decode_workers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a gRPC worker
|
||||||
|
pub fn create_grpc(url: String, worker_type: WorkerType, port: Option<u16>) -> Box<dyn Worker> {
|
||||||
|
Box::new(BasicWorker::with_connection_mode(
|
||||||
|
url,
|
||||||
|
worker_type,
|
||||||
|
ConnectionMode::Grpc { port },
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a gRPC worker with custom circuit breaker configuration
|
||||||
|
pub fn create_grpc_with_config(
|
||||||
|
url: String,
|
||||||
|
worker_type: WorkerType,
|
||||||
|
port: Option<u16>,
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig,
|
||||||
|
) -> Box<dyn Worker> {
|
||||||
|
Box::new(
|
||||||
|
BasicWorker::with_connection_mode(url, worker_type, ConnectionMode::Grpc { port })
|
||||||
|
.with_circuit_breaker_config(circuit_breaker_config),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a DP-aware worker of specified type
|
/// Create a DP-aware worker of specified type
|
||||||
pub fn create_dp_aware(
|
pub fn create_dp_aware(
|
||||||
base_url: String,
|
base_url: String,
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
//! Factory for creating router instances
|
//! Factory for creating router instances
|
||||||
|
|
||||||
use super::{pd_router::PDRouter, router::Router, RouterTrait};
|
use super::{
|
||||||
|
http::{pd_router::PDRouter, router::Router},
|
||||||
|
RouterTrait,
|
||||||
|
};
|
||||||
use crate::config::{PolicyConfig, RoutingMode};
|
use crate::config::{PolicyConfig, RoutingMode};
|
||||||
use crate::policies::PolicyFactory;
|
use crate::policies::PolicyFactory;
|
||||||
use crate::server::AppContext;
|
use crate::server::AppContext;
|
||||||
@@ -17,7 +20,9 @@ impl RouterFactory {
|
|||||||
return Self::create_igw_router(ctx).await;
|
return Self::create_igw_router(ctx).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default to proxy mode
|
// TODO: Add gRPC mode check here when implementing gRPC support
|
||||||
|
|
||||||
|
// Default to HTTP proxy mode
|
||||||
match &ctx.router_config.mode {
|
match &ctx.router_config.mode {
|
||||||
RoutingMode::Regular { worker_urls } => {
|
RoutingMode::Regular { worker_urls } => {
|
||||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
|
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||||
@@ -101,6 +106,29 @@ impl RouterFactory {
|
|||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a gRPC router with injected policy
|
||||||
|
pub async fn create_grpc_router(
|
||||||
|
_worker_urls: &[String],
|
||||||
|
_policy_config: &PolicyConfig,
|
||||||
|
_ctx: &Arc<AppContext>,
|
||||||
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
|
// For now, return an error as gRPC router is not yet implemented
|
||||||
|
Err("gRPC router is not yet implemented".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a gRPC PD router (placeholder for now)
|
||||||
|
pub async fn create_grpc_pd_router(
|
||||||
|
_prefill_urls: &[(String, Option<u16>)],
|
||||||
|
_decode_urls: &[String],
|
||||||
|
_prefill_policy_config: Option<&PolicyConfig>,
|
||||||
|
_decode_policy_config: Option<&PolicyConfig>,
|
||||||
|
_main_policy_config: &PolicyConfig,
|
||||||
|
_ctx: &Arc<AppContext>,
|
||||||
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
|
// For now, return an error as gRPC PD router is not yet implemented
|
||||||
|
Err("gRPC PD router is not yet implemented".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
/// Create an IGW router (placeholder for future implementation)
|
/// Create an IGW router (placeholder for future implementation)
|
||||||
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
// For now, return an error indicating IGW is not yet implemented
|
// For now, return an error indicating IGW is not yet implemented
|
||||||
|
|||||||
4
sgl-router/src/routers/grpc/mod.rs
Normal file
4
sgl-router/src/routers/grpc/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
//! gRPC router implementations
|
||||||
|
|
||||||
|
pub mod pd_router;
|
||||||
|
pub mod router;
|
||||||
110
sgl-router/src/routers/grpc/pd_router.rs
Normal file
110
sgl-router/src/routers/grpc/pd_router.rs
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
// PD (Prefill-Decode) gRPC Router Implementation
|
||||||
|
// TODO: Implement gRPC-based PD router for disaggregated prefill-decode systems
|
||||||
|
|
||||||
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
extract::Request,
|
||||||
|
http::{HeaderMap, StatusCode},
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Stand-in type for the future gRPC prefill-decode router; carries no state.
#[derive(Debug)]
pub struct GrpcPDRouter;

impl GrpcPDRouter {
    /// Attempt to construct the router.
    ///
    /// # Errors
    /// Always returns `Err` for now, because initialization is not implemented.
    pub async fn new() -> Result<Self, String> {
        // TODO: Implement gRPC PD router initialization
        Err(String::from("gRPC PD router not yet implemented"))
    }
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl RouterTrait for GrpcPDRouter {
|
||||||
|
fn as_any(&self) -> &dyn std::any::Any {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_generate(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_body: &crate::protocols::spec::GenerateRequest,
|
||||||
|
) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_chat(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_body: &crate::protocols::spec::ChatCompletionRequest,
|
||||||
|
) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_completion(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_body: &crate::protocols::spec::CompletionRequest,
|
||||||
|
) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn flush_cache(&self) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_worker_loads(&self) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn router_type(&self) -> &'static str {
|
||||||
|
"grpc_pd"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn readiness(&self) -> Response {
|
||||||
|
(StatusCode::SERVICE_UNAVAILABLE).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl WorkerManagement for GrpcPDRouter {
|
||||||
|
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||||
|
Err("Not implemented".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_worker(&self, _worker_url: &str) {}
|
||||||
|
|
||||||
|
fn get_worker_urls(&self) -> Vec<String> {
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
}
|
||||||
110
sgl-router/src/routers/grpc/router.rs
Normal file
110
sgl-router/src/routers/grpc/router.rs
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
// gRPC Router Implementation
|
||||||
|
// TODO: Implement gRPC-based router
|
||||||
|
|
||||||
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
extract::Request,
|
||||||
|
http::{HeaderMap, StatusCode},
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Stand-in type for the future gRPC router; carries no state.
#[derive(Debug)]
pub struct GrpcRouter;

impl GrpcRouter {
    /// Attempt to construct the router.
    ///
    /// # Errors
    /// Always returns `Err` for now, because initialization is not implemented.
    pub async fn new() -> Result<Self, String> {
        // TODO: Implement gRPC router initialization
        Err(String::from("gRPC router not yet implemented"))
    }
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl RouterTrait for GrpcRouter {
|
||||||
|
fn as_any(&self) -> &dyn std::any::Any {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_generate(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_body: &crate::protocols::spec::GenerateRequest,
|
||||||
|
) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_chat(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_body: &crate::protocols::spec::ChatCompletionRequest,
|
||||||
|
) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_completion(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_body: &crate::protocols::spec::CompletionRequest,
|
||||||
|
) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn flush_cache(&self) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_worker_loads(&self) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn router_type(&self) -> &'static str {
|
||||||
|
"grpc"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn readiness(&self) -> Response {
|
||||||
|
(StatusCode::SERVICE_UNAVAILABLE).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl WorkerManagement for GrpcRouter {
|
||||||
|
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||||
|
Err("Not implemented".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_worker(&self, _worker_url: &str) {}
|
||||||
|
|
||||||
|
fn get_worker_urls(&self) -> Vec<String> {
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
}
|
||||||
5
sgl-router/src/routers/http/mod.rs
Normal file
5
sgl-router/src/routers/http/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//! HTTP router implementations
|
||||||
|
|
||||||
|
pub mod pd_router;
|
||||||
|
pub mod pd_types;
|
||||||
|
pub mod router;
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
// PD (Prefill-Decode) Router Implementation
|
// PD (Prefill-Decode) Router Implementation
|
||||||
// This module handles routing for disaggregated prefill-decode systems
|
// This module handles routing for disaggregated prefill-decode systems
|
||||||
use super::header_utils;
|
|
||||||
use super::pd_types::{api_path, PDRouterError};
|
use super::pd_types::{api_path, PDRouterError};
|
||||||
use crate::config::types::{
|
use crate::config::types::{
|
||||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||||
@@ -16,6 +15,7 @@ use crate::protocols::spec::{
|
|||||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray,
|
||||||
UserMessageContent,
|
UserMessageContent,
|
||||||
};
|
};
|
||||||
|
use crate::routers::header_utils;
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{
|
||||||
@@ -72,7 +72,7 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Private helper method to perform health check on a new server
|
// Private helper method to perform health check on a new server
|
||||||
async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> {
|
async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> {
|
||||||
crate::routers::router::Router::wait_for_healthy_workers(
|
crate::routers::http::router::Router::wait_for_healthy_workers(
|
||||||
&[url.to_string()],
|
&[url.to_string()],
|
||||||
self.timeout_secs,
|
self.timeout_secs,
|
||||||
self.interval_secs,
|
self.interval_secs,
|
||||||
@@ -435,7 +435,7 @@ impl PDRouter {
|
|||||||
.map(|worker| worker.url().to_string())
|
.map(|worker| worker.url().to_string())
|
||||||
.collect();
|
.collect();
|
||||||
if !all_urls.is_empty() {
|
if !all_urls.is_empty() {
|
||||||
crate::routers::router::Router::wait_for_healthy_workers(
|
crate::routers::http::router::Router::wait_for_healthy_workers(
|
||||||
&all_urls,
|
&all_urls,
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
@@ -1935,6 +1935,14 @@ impl RouterTrait for PDRouter {
|
|||||||
self.execute_dual_dispatch(headers, body, context).await
|
self.execute_dual_dispatch(headers, body, context).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
async fn flush_cache(&self) -> Response {
|
||||||
// Process both prefill and decode workers
|
// Process both prefill and decode workers
|
||||||
let (prefill_results, prefill_errors) = self
|
let (prefill_results, prefill_errors) = self
|
||||||
@@ -2040,7 +2048,7 @@ impl RouterTrait for PDRouter {
|
|||||||
let total_decode = self.decode_workers.read().unwrap().len();
|
let total_decode = self.decode_workers.read().unwrap().len();
|
||||||
|
|
||||||
if healthy_prefill_count > 0 && healthy_decode_count > 0 {
|
if healthy_prefill_count > 0 && healthy_decode_count > 0 {
|
||||||
Json(serde_json::json!({
|
Json(json!({
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"prefill": {
|
"prefill": {
|
||||||
"healthy": healthy_prefill_count,
|
"healthy": healthy_prefill_count,
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
use super::header_utils;
|
|
||||||
use crate::config::types::{
|
use crate::config::types::{
|
||||||
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
CircuitBreakerConfig as ConfigCircuitBreakerConfig,
|
||||||
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
|
HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig,
|
||||||
@@ -12,6 +11,7 @@ use crate::policies::LoadBalancingPolicy;
|
|||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest,
|
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest,
|
||||||
};
|
};
|
||||||
|
use crate::routers::header_utils;
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
@@ -393,7 +393,7 @@ impl Router {
|
|||||||
|
|
||||||
// Helper method to proxy GET requests to the first available worker
|
// Helper method to proxy GET requests to the first available worker
|
||||||
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
|
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
|
||||||
let headers = super::header_utils::copy_request_headers(&req);
|
let headers = header_utils::copy_request_headers(&req);
|
||||||
|
|
||||||
match self.select_first_worker() {
|
match self.select_first_worker() {
|
||||||
Ok(worker_url) => {
|
Ok(worker_url) => {
|
||||||
@@ -667,7 +667,7 @@ impl Router {
|
|||||||
|
|
||||||
if !is_stream {
|
if !is_stream {
|
||||||
// For non-streaming requests, preserve headers
|
// For non-streaming requests, preserve headers
|
||||||
let response_headers = super::header_utils::preserve_response_headers(res.headers());
|
let response_headers = header_utils::preserve_response_headers(res.headers());
|
||||||
|
|
||||||
let response = match res.bytes().await {
|
let response = match res.bytes().await {
|
||||||
Ok(body) => {
|
Ok(body) => {
|
||||||
@@ -1198,6 +1198,14 @@ impl RouterTrait for Router {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
async fn flush_cache(&self) -> Response {
|
||||||
// Get all worker URLs
|
// Get all worker URLs
|
||||||
let worker_urls = self.get_worker_urls();
|
let worker_urls = self.get_worker_urls();
|
||||||
@@ -12,10 +12,9 @@ use std::fmt::Debug;
|
|||||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
|
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
|
pub mod grpc;
|
||||||
pub mod header_utils;
|
pub mod header_utils;
|
||||||
pub mod pd_router;
|
pub mod http;
|
||||||
pub mod pd_types;
|
|
||||||
pub mod router;
|
|
||||||
|
|
||||||
pub use factory::RouterFactory;
|
pub use factory::RouterFactory;
|
||||||
|
|
||||||
@@ -77,6 +76,10 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
|||||||
body: &CompletionRequest,
|
body: &CompletionRequest,
|
||||||
) -> Response;
|
) -> Response;
|
||||||
|
|
||||||
|
/// Route an embeddings request, given optional request headers and the raw request body.
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||||
|
|
||||||
|
/// Route a rerank request, given optional request headers and the raw request body.
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||||
|
|
||||||
/// Flush cache on all workers
|
/// Flush cache on all workers
|
||||||
async fn flush_cache(&self) -> Response;
|
async fn flush_cache(&self) -> Response;
|
||||||
|
|
||||||
|
|||||||
@@ -383,7 +383,7 @@ async fn handle_pod_event(
|
|||||||
// Handle PD mode with specific pod types
|
// Handle PD mode with specific pod types
|
||||||
let result = if pd_mode && pod_info.pod_type.is_some() {
|
let result = if pd_mode && pod_info.pod_type.is_some() {
|
||||||
// Need to import PDRouter type
|
// Need to import PDRouter type
|
||||||
use crate::routers::pd_router::PDRouter;
|
use crate::routers::http::pd_router::PDRouter;
|
||||||
|
|
||||||
// Try to downcast to PDRouter
|
// Try to downcast to PDRouter
|
||||||
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
|
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
|
||||||
@@ -453,7 +453,7 @@ async fn handle_pod_deletion(
|
|||||||
|
|
||||||
// Handle PD mode removal
|
// Handle PD mode removal
|
||||||
if pd_mode && pod_info.pod_type.is_some() {
|
if pd_mode && pod_info.pod_type.is_some() {
|
||||||
use crate::routers::pd_router::PDRouter;
|
use crate::routers::http::pd_router::PDRouter;
|
||||||
|
|
||||||
// Try to downcast to PDRouter for PD-specific removal
|
// Try to downcast to PDRouter for PD-specific removal
|
||||||
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
|
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
|
||||||
@@ -581,7 +581,7 @@ mod tests {
|
|||||||
async fn create_test_router() -> Arc<dyn RouterTrait> {
|
async fn create_test_router() -> Arc<dyn RouterTrait> {
|
||||||
use crate::config::PolicyConfig;
|
use crate::config::PolicyConfig;
|
||||||
use crate::policies::PolicyFactory;
|
use crate::policies::PolicyFactory;
|
||||||
use crate::routers::router::Router;
|
use crate::routers::http::router::Router;
|
||||||
|
|
||||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||||
let router = Router::new(
|
let router = Router::new(
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ mod test_pd_routing {
|
|||||||
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
};
|
};
|
||||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||||
use sglang_router_rs::routers::pd_types::get_hostname;
|
use sglang_router_rs::routers::http::pd_types::get_hostname;
|
||||||
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
|
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
|
||||||
use sglang_router_rs::routers::RouterFactory;
|
use sglang_router_rs::routers::RouterFactory;
|
||||||
|
|
||||||
// Test-only struct to help validate PD request parsing
|
// Test-only struct to help validate PD request parsing
|
||||||
|
|||||||
Reference in New Issue
Block a user