From 5343058875a7c07ad62cfef9681f26ffbe359859 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Thu, 28 Aug 2025 12:07:06 -0700 Subject: [PATCH] [router] grpc router bootstraps (#9759) --- sgl-router/benches/request_processing.rs | 4 +- sgl-router/src/core/mod.rs | 4 +- sgl-router/src/core/worker.rs | 68 +++++++++++ sgl-router/src/routers/factory.rs | 32 ++++- sgl-router/src/routers/grpc/mod.rs | 4 + sgl-router/src/routers/grpc/pd_router.rs | 110 ++++++++++++++++++ sgl-router/src/routers/grpc/router.rs | 110 ++++++++++++++++++ sgl-router/src/routers/http/mod.rs | 5 + .../src/routers/{ => http}/pd_router.rs | 16 ++- sgl-router/src/routers/{ => http}/pd_types.rs | 0 sgl-router/src/routers/{ => http}/router.rs | 14 ++- sgl-router/src/routers/mod.rs | 9 +- sgl-router/src/service_discovery.rs | 6 +- sgl-router/tests/test_pd_routing.rs | 4 +- 14 files changed, 366 insertions(+), 20 deletions(-) create mode 100644 sgl-router/src/routers/grpc/mod.rs create mode 100644 sgl-router/src/routers/grpc/pd_router.rs create mode 100644 sgl-router/src/routers/grpc/router.rs create mode 100644 sgl-router/src/routers/http/mod.rs rename sgl-router/src/routers/{ => http}/pd_router.rs (99%) rename sgl-router/src/routers/{ => http}/pd_types.rs (100%) rename sgl-router/src/routers/{ => http}/router.rs (99%) diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index 3edb2fc3d..efd08bf74 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -7,7 +7,9 @@ use sglang_router_rs::protocols::spec::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, }; -use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap}; +use sglang_router_rs::routers::http::pd_types::{ + generate_room_id, get_hostname, RequestWithBootstrap, +}; fn create_test_worker() -> BasicWorker { BasicWorker::new( diff 
--git a/sgl-router/src/core/mod.rs b/sgl-router/src/core/mod.rs index 4ccb05fb0..b46810b4c 100644 --- a/sgl-router/src/core/mod.rs +++ b/sgl-router/src/core/mod.rs @@ -19,6 +19,6 @@ pub use circuit_breaker::{ pub use error::{WorkerError, WorkerResult}; pub use retry::{is_retryable_status, BackoffCalculator, RetryError, RetryExecutor}; pub use worker::{ - start_health_checker, BasicWorker, DPAwareWorker, HealthChecker, HealthConfig, Worker, - WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType, + start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig, + Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType, }; diff --git a/sgl-router/src/core/worker.rs b/sgl-router/src/core/worker.rs index f3039ae21..b054355f0 100644 --- a/sgl-router/src/core/worker.rs +++ b/sgl-router/src/core/worker.rs @@ -24,6 +24,9 @@ pub trait Worker: Send + Sync + fmt::Debug { /// Get the worker's type (Regular, Prefill, or Decode) fn worker_type(&self) -> WorkerType; + /// Get the worker's connection mode (HTTP or gRPC) + fn connection_mode(&self) -> ConnectionMode; + /// Check if the worker is currently healthy fn is_healthy(&self) -> bool; @@ -152,6 +155,30 @@ pub trait Worker: Send + Sync + fmt::Debug { } } +/// Connection mode for worker communication +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ConnectionMode { + /// HTTP/REST connection + Http, + /// gRPC connection + Grpc { + /// Optional port for gRPC endpoint (if different from URL) + port: Option, + }, +} + +impl fmt::Display for ConnectionMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ConnectionMode::Http => write!(f, "HTTP"), + ConnectionMode::Grpc { port } => match port { + Some(p) => write!(f, "gRPC(port:{})", p), + None => write!(f, "gRPC"), + }, + } + } +} + /// Worker type classification #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum WorkerType { @@ -213,6 +240,8 @@ pub struct WorkerMetadata { pub url: 
String, /// Worker type pub worker_type: WorkerType, + /// Connection mode + pub connection_mode: ConnectionMode, /// Additional labels/tags pub labels: std::collections::HashMap, /// Health check configuration @@ -233,9 +262,18 @@ pub struct BasicWorker { impl BasicWorker { pub fn new(url: String, worker_type: WorkerType) -> Self { + Self::with_connection_mode(url, worker_type, ConnectionMode::Http) + } + + pub fn with_connection_mode( + url: String, + worker_type: WorkerType, + connection_mode: ConnectionMode, + ) -> Self { let metadata = WorkerMetadata { url: url.clone(), worker_type, + connection_mode, labels: std::collections::HashMap::new(), health_config: HealthConfig::default(), }; @@ -298,6 +336,10 @@ impl Worker for BasicWorker { self.metadata.worker_type.clone() } + fn connection_mode(&self) -> ConnectionMode { + self.metadata.connection_mode.clone() + } + fn is_healthy(&self) -> bool { self.healthy.load(Ordering::Acquire) } @@ -434,6 +476,10 @@ impl Worker for DPAwareWorker { self.base_worker.worker_type() } + fn connection_mode(&self) -> ConnectionMode { + self.base_worker.connection_mode() + } + fn is_healthy(&self) -> bool { self.base_worker.is_healthy() } @@ -603,6 +649,28 @@ impl WorkerFactory { (regular_workers, prefill_workers, decode_workers) } + /// Create a gRPC worker + pub fn create_grpc(url: String, worker_type: WorkerType, port: Option) -> Box { + Box::new(BasicWorker::with_connection_mode( + url, + worker_type, + ConnectionMode::Grpc { port }, + )) + } + + /// Create a gRPC worker with custom circuit breaker configuration + pub fn create_grpc_with_config( + url: String, + worker_type: WorkerType, + port: Option, + circuit_breaker_config: CircuitBreakerConfig, + ) -> Box { + Box::new( + BasicWorker::with_connection_mode(url, worker_type, ConnectionMode::Grpc { port }) + .with_circuit_breaker_config(circuit_breaker_config), + ) + } + /// Create a DP-aware worker of specified type pub fn create_dp_aware( base_url: String, diff --git 
a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs index 7b4f848bc..c0a4aa6d0 100644 --- a/sgl-router/src/routers/factory.rs +++ b/sgl-router/src/routers/factory.rs @@ -1,6 +1,9 @@ //! Factory for creating router instances -use super::{pd_router::PDRouter, router::Router, RouterTrait}; +use super::{ + http::{pd_router::PDRouter, router::Router}, + RouterTrait, +}; use crate::config::{PolicyConfig, RoutingMode}; use crate::policies::PolicyFactory; use crate::server::AppContext; @@ -17,7 +20,9 @@ impl RouterFactory { return Self::create_igw_router(ctx).await; } - // Default to proxy mode + // TODO: Add gRPC mode check here when implementing gRPC support + + // Default to HTTP proxy mode match &ctx.router_config.mode { RoutingMode::Regular { worker_urls } => { Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await @@ -101,6 +106,29 @@ impl RouterFactory { Ok(Box::new(router)) } + /// Create a gRPC router with injected policy + pub async fn create_grpc_router( + _worker_urls: &[String], + _policy_config: &PolicyConfig, + _ctx: &Arc, + ) -> Result, String> { + // For now, return an error as gRPC router is not yet implemented + Err("gRPC router is not yet implemented".to_string()) + } + + /// Create a gRPC PD router (placeholder for now) + pub async fn create_grpc_pd_router( + _prefill_urls: &[(String, Option)], + _decode_urls: &[String], + _prefill_policy_config: Option<&PolicyConfig>, + _decode_policy_config: Option<&PolicyConfig>, + _main_policy_config: &PolicyConfig, + _ctx: &Arc, + ) -> Result, String> { + // For now, return an error as gRPC PD router is not yet implemented + Err("gRPC PD router is not yet implemented".to_string()) + } + /// Create an IGW router (placeholder for future implementation) async fn create_igw_router(_ctx: &Arc) -> Result, String> { // For now, return an error indicating IGW is not yet implemented diff --git a/sgl-router/src/routers/grpc/mod.rs b/sgl-router/src/routers/grpc/mod.rs new file mode 
100644 index 000000000..a6a5d8eec --- /dev/null +++ b/sgl-router/src/routers/grpc/mod.rs @@ -0,0 +1,4 @@ +//! gRPC router implementations + +pub mod pd_router; +pub mod router; diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs new file mode 100644 index 000000000..e3f453186 --- /dev/null +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -0,0 +1,110 @@ +// PD (Prefill-Decode) gRPC Router Implementation +// TODO: Implement gRPC-based PD router for disaggregated prefill-decode systems + +use crate::routers::{RouterTrait, WorkerManagement}; +use async_trait::async_trait; +use axum::{ + body::Body, + extract::Request, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, +}; + +/// Placeholder for gRPC PD router +#[derive(Debug)] +pub struct GrpcPDRouter; + +impl GrpcPDRouter { + pub async fn new() -> Result { + // TODO: Implement gRPC PD router initialization + Err("gRPC PD router not yet implemented".to_string()) + } +} + +#[async_trait] +impl RouterTrait for GrpcPDRouter { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + async fn health(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn health_generate(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn get_server_info(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn get_models(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn get_model_info(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_generate( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::GenerateRequest, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_chat( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::ChatCompletionRequest, + ) -> 
Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_completion( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::CompletionRequest, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn flush_cache(&self) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn get_worker_loads(&self) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + fn router_type(&self) -> &'static str { + "grpc_pd" + } + + fn readiness(&self) -> Response { + (StatusCode::SERVICE_UNAVAILABLE).into_response() + } +} + +#[async_trait] +impl WorkerManagement for GrpcPDRouter { + async fn add_worker(&self, _worker_url: &str) -> Result { + Err("Not implemented".to_string()) + } + + fn remove_worker(&self, _worker_url: &str) {} + + fn get_worker_urls(&self) -> Vec { + vec![] + } +} diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs new file mode 100644 index 000000000..f5fc407f7 --- /dev/null +++ b/sgl-router/src/routers/grpc/router.rs @@ -0,0 +1,110 @@ +// gRPC Router Implementation +// TODO: Implement gRPC-based router + +use crate::routers::{RouterTrait, WorkerManagement}; +use async_trait::async_trait; +use axum::{ + body::Body, + extract::Request, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, +}; + +/// Placeholder for gRPC router +#[derive(Debug)] +pub struct GrpcRouter; + +impl GrpcRouter { + pub async fn new() -> Result { + // TODO: Implement gRPC router initialization + Err("gRPC router not yet implemented".to_string()) + } +} + +#[async_trait] +impl RouterTrait for GrpcRouter { + fn as_any(&self) -> &dyn std::any::Any { + 
self + } + + async fn health(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn health_generate(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn get_server_info(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn get_models(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn get_model_info(&self, _req: Request) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_generate( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::GenerateRequest, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_chat( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::ChatCompletionRequest, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_completion( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::CompletionRequest, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn flush_cache(&self) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn get_worker_loads(&self) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + fn router_type(&self) -> &'static str { + "grpc" + } + + fn readiness(&self) -> Response { + (StatusCode::SERVICE_UNAVAILABLE).into_response() + } +} + +#[async_trait] +impl WorkerManagement for GrpcRouter { + async fn add_worker(&self, _worker_url: &str) -> Result { + Err("Not implemented".to_string()) + } + + fn remove_worker(&self, 
_worker_url: &str) {} + + fn get_worker_urls(&self) -> Vec { + vec![] + } +} diff --git a/sgl-router/src/routers/http/mod.rs b/sgl-router/src/routers/http/mod.rs new file mode 100644 index 000000000..3f31b6f86 --- /dev/null +++ b/sgl-router/src/routers/http/mod.rs @@ -0,0 +1,5 @@ +//! HTTP router implementations + +pub mod pd_router; +pub mod pd_types; +pub mod router; diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs similarity index 99% rename from sgl-router/src/routers/pd_router.rs rename to sgl-router/src/routers/http/pd_router.rs index 9562c08e4..887be65c4 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -1,6 +1,5 @@ // PD (Prefill-Decode) Router Implementation // This module handles routing for disaggregated prefill-decode systems -use super::header_utils; use super::pd_types::{api_path, PDRouterError}; use crate::config::types::{ CircuitBreakerConfig as ConfigCircuitBreakerConfig, @@ -16,6 +15,7 @@ use crate::protocols::spec::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray, UserMessageContent, }; +use crate::routers::header_utils; use crate::routers::{RouterTrait, WorkerManagement}; use async_trait::async_trait; use axum::{ @@ -72,7 +72,7 @@ impl PDRouter { // Private helper method to perform health check on a new server async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> { - crate::routers::router::Router::wait_for_healthy_workers( + crate::routers::http::router::Router::wait_for_healthy_workers( &[url.to_string()], self.timeout_secs, self.interval_secs, @@ -435,7 +435,7 @@ impl PDRouter { .map(|worker| worker.url().to_string()) .collect(); if !all_urls.is_empty() { - crate::routers::router::Router::wait_for_healthy_workers( + crate::routers::http::router::Router::wait_for_healthy_workers( &all_urls, timeout_secs, interval_secs, @@ -1935,6 +1935,14 @@ impl RouterTrait for PDRouter { 
self.execute_dual_dispatch(headers, body, context).await } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + todo!() + } + + async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + todo!() + } + async fn flush_cache(&self) -> Response { // Process both prefill and decode workers let (prefill_results, prefill_errors) = self @@ -2040,7 +2048,7 @@ impl RouterTrait for PDRouter { let total_decode = self.decode_workers.read().unwrap().len(); if healthy_prefill_count > 0 && healthy_decode_count > 0 { - Json(serde_json::json!({ + Json(json!({ "status": "ready", "prefill": { "healthy": healthy_prefill_count, diff --git a/sgl-router/src/routers/pd_types.rs b/sgl-router/src/routers/http/pd_types.rs similarity index 100% rename from sgl-router/src/routers/pd_types.rs rename to sgl-router/src/routers/http/pd_types.rs diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/http/router.rs similarity index 99% rename from sgl-router/src/routers/router.rs rename to sgl-router/src/routers/http/router.rs index 077ad6d4f..6e63c7f4a 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -1,4 +1,3 @@ -use super::header_utils; use crate::config::types::{ CircuitBreakerConfig as ConfigCircuitBreakerConfig, HealthCheckConfig as ConfigHealthCheckConfig, RetryConfig, @@ -12,6 +11,7 @@ use crate::policies::LoadBalancingPolicy; use crate::protocols::spec::{ ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, }; +use crate::routers::header_utils; use crate::routers::{RouterTrait, WorkerManagement}; use axum::{ body::Body, @@ -393,7 +393,7 @@ impl Router { // Helper method to proxy GET requests to the first available worker async fn proxy_get_request(&self, req: Request, endpoint: &str) -> Response { - let headers = super::header_utils::copy_request_headers(&req); + let headers = header_utils::copy_request_headers(&req); match 
self.select_first_worker() { Ok(worker_url) => { @@ -667,7 +667,7 @@ impl Router { if !is_stream { // For non-streaming requests, preserve headers - let response_headers = super::header_utils::preserve_response_headers(res.headers()); + let response_headers = header_utils::preserve_response_headers(res.headers()); let response = match res.bytes().await { Ok(body) => { @@ -1198,6 +1198,14 @@ impl RouterTrait for Router { .await } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + todo!() + } + + async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + todo!() + } + async fn flush_cache(&self) -> Response { // Get all worker URLs let worker_urls = self.get_worker_urls(); diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index a0882c176..76ef98821 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -12,10 +12,9 @@ use std::fmt::Debug; use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; pub mod factory; +pub mod grpc; pub mod header_utils; -pub mod pd_router; -pub mod pd_types; -pub mod router; +pub mod http; pub use factory::RouterFactory; @@ -77,6 +76,10 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { body: &CompletionRequest, ) -> Response; + async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response; + + async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response; + /// Flush cache on all workers async fn flush_cache(&self) -> Response; diff --git a/sgl-router/src/service_discovery.rs b/sgl-router/src/service_discovery.rs index 2270671c7..52cdfdea3 100644 --- a/sgl-router/src/service_discovery.rs +++ b/sgl-router/src/service_discovery.rs @@ -383,7 +383,7 @@ async fn handle_pod_event( // Handle PD mode with specific pod types let result = if pd_mode && pod_info.pod_type.is_some() { // Need to import PDRouter type - use 
crate::routers::pd_router::PDRouter; + use crate::routers::http::pd_router::PDRouter; // Try to downcast to PDRouter if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() { @@ -453,7 +453,7 @@ async fn handle_pod_deletion( // Handle PD mode removal if pd_mode && pod_info.pod_type.is_some() { - use crate::routers::pd_router::PDRouter; + use crate::routers::http::pd_router::PDRouter; // Try to downcast to PDRouter for PD-specific removal if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() { @@ -581,7 +581,7 @@ mod tests { async fn create_test_router() -> Arc<dyn RouterTrait> { use crate::config::PolicyConfig; use crate::policies::PolicyFactory; - use crate::routers::router::Router; + use crate::routers::http::router::Router; let policy = PolicyFactory::create_from_config(&PolicyConfig::Random); let router = Router::new( diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs index 401ee1119..bcea75a6a 100644 --- a/sgl-router/tests/test_pd_routing.rs +++ b/sgl-router/tests/test_pd_routing.rs @@ -5,8 +5,8 @@ mod test_pd_routing { CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, }; use sglang_router_rs::core::{WorkerFactory, WorkerType}; - use sglang_router_rs::routers::pd_types::get_hostname; - use sglang_router_rs::routers::pd_types::PDSelectionPolicy; + use sglang_router_rs::routers::http::pd_types::get_hostname; + use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy; use sglang_router_rs::routers::RouterFactory; // Test-only struct to help validate PD request parsing