Files
sglang/sgl-router/src/routers/grpc/pd_router.rs

286 lines
8.9 KiB
Rust

// PD (Prefill-Decode) gRPC Router Implementation
use crate::config::types::RetryConfig;
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
use crate::policies::PolicyRegistry;
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesGetParams, ResponsesRequest,
};
use crate::reasoning_parser::ReasoningParserFactory;
use crate::routers::RouterTrait;
use crate::server::AppContext;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory;
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use std::sync::Arc;
use tracing::debug;
/// gRPC PD (Prefill-Decode) router implementation for SGLang
///
/// Built by `GrpcPDRouter::new` from the application context. Generate and
/// chat requests are forwarded to `pipeline`; the remaining `RouterTrait`
/// endpoints currently answer `501 Not Implemented`.
#[derive(Clone)]
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
/// Worker registry cloned from the application context; also queried by
/// the `Debug` impl to count prefill and decode workers.
worker_registry: Arc<WorkerRegistry>,
/// Policy registry cloned from the application context and handed to the
/// pipeline constructor.
policy_registry: Arc<PolicyRegistry>,
/// Tokenizer taken from the context; construction fails without it.
tokenizer: Arc<dyn Tokenizer>,
/// Reasoning parser factory taken from the context; construction fails
/// without it.
reasoning_parser_factory: ReasoningParserFactory,
/// Tool parser factory taken from the context; construction fails
/// without it.
tool_parser_factory: ToolParserFactory,
/// Copied from `router_config.dp_aware`; reported in the `Debug` output.
dp_aware: bool,
/// Copied from `router_config.api_key`. Stored only; not read by the
/// visible code.
api_key: Option<String>,
/// Result of `router_config.effective_retry_config()`. Stored only; not
/// read by the visible code.
retry_config: RetryConfig,
/// Reasoning parser name configured on the context, if any.
configured_reasoning_parser: Option<String>,
/// Tool parser name configured on the context, if any.
configured_tool_parser: Option<String>,
/// Pipeline created with `ChatCompletionPipeline::new_pd`; executes both
/// generate and chat requests.
pipeline: super::pipeline::ChatCompletionPipeline,
/// Tokenizer and parser factories bundled together and passed to the
/// pipeline on every request.
shared_components: Arc<super::context::SharedComponents>,
}
impl GrpcPDRouter {
    /// Builds a gRPC PD router from the shared application context.
    ///
    /// The tokenizer and both parser factories are mandatory. They are cloned
    /// out of `ctx` and wired into a response processor, a streaming
    /// processor, and finally the PD pipeline.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` when `ctx` lacks the tokenizer, the
    /// reasoning parser factory, or the tool parser factory.
    pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
        // Registries shared with the rest of the application.
        let worker_registry = ctx.worker_registry.clone();
        let policy_registry = ctx.policy_registry.clone();

        // Mandatory components: bail out as soon as one is missing.
        let tokenizer = match ctx.tokenizer.as_ref() {
            Some(found) => found.clone(),
            None => return Err("gRPC PD router requires tokenizer".to_string()),
        };
        let reasoning_parser_factory = match ctx.reasoning_parser_factory.as_ref() {
            Some(found) => found.clone(),
            None => return Err("gRPC PD router requires reasoning parser factory".to_string()),
        };
        let tool_parser_factory = match ctx.tool_parser_factory.as_ref() {
            Some(found) => found.clone(),
            None => return Err("gRPC PD router requires tool parser factory".to_string()),
        };

        // Optional parser names, read once and reused below.
        let configured_tool_parser = ctx.configured_tool_parser.clone();
        let configured_reasoning_parser = ctx.configured_reasoning_parser.clone();

        // Bundle handed to the pipeline alongside every request.
        let shared_components = Arc::new(super::context::SharedComponents {
            tokenizer: tokenizer.clone(),
            tool_parser_factory: tool_parser_factory.clone(),
            reasoning_parser_factory: reasoning_parser_factory.clone(),
        });

        // Non-streaming response handling.
        let processor = super::processing::ResponseProcessor::new(
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            configured_tool_parser.clone(),
            configured_reasoning_parser.clone(),
        );

        // Streaming response handling.
        let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
            tokenizer.clone(),
            tool_parser_factory.clone(),
            reasoning_parser_factory.clone(),
            configured_tool_parser.clone(),
            configured_reasoning_parser.clone(),
        ));

        // PD flavour of the chat-completion pipeline.
        let pipeline = super::pipeline::ChatCompletionPipeline::new_pd(
            worker_registry.clone(),
            policy_registry.clone(),
            processor,
            streaming_processor,
        );

        let router_config = &ctx.router_config;
        Ok(Self {
            worker_registry,
            policy_registry,
            tokenizer,
            reasoning_parser_factory,
            tool_parser_factory,
            dp_aware: router_config.dp_aware,
            api_key: router_config.api_key.clone(),
            retry_config: router_config.effective_retry_config(),
            configured_reasoning_parser,
            configured_tool_parser,
            pipeline,
            shared_components,
        })
    }

    /// Forwards a generate request to the PD pipeline.
    ///
    /// The borrowed body, headers and model id are copied into owned values
    /// because the pipeline takes ownership of its arguments. The same path
    /// serves streaming and non-streaming requests.
    async fn route_generate_impl(
        &self,
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing generate request for model: {:?} (PD mode)",
            model_id
        );

        let request = Arc::new(body.clone());
        let owned_headers = headers.cloned();
        let owned_model_id = model_id.map(str::to_owned);
        let components = Arc::clone(&self.shared_components);

        self.pipeline
            .execute_generate(request, owned_headers, owned_model_id, components)
            .await
    }

    /// Forwards a chat completion request to the PD pipeline.
    ///
    /// The borrowed body, headers and model id are copied into owned values
    /// because the pipeline takes ownership of its arguments. The same path
    /// serves streaming and non-streaming requests.
    async fn route_chat_impl(
        &self,
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        debug!(
            "Processing chat completion request for model: {:?} (PD mode)",
            model_id
        );

        let request = Arc::new(body.clone());
        let owned_headers = headers.cloned();
        let owned_model_id = model_id.map(str::to_owned);
        let components = Arc::clone(&self.shared_components);

        self.pipeline
            .execute_chat(request, owned_headers, owned_model_id, components)
            .await
    }
}
impl std::fmt::Debug for GrpcPDRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let prefill_workers = self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(ConnectionMode::Grpc { port: None }),
false,
);
let decode_workers = self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Decode),
Some(ConnectionMode::Grpc { port: None }),
false,
);
f.debug_struct("GrpcPDRouter")
.field("prefill_workers_count", &prefill_workers.len())
.field("decode_workers_count", &decode_workers.len())
.field("dp_aware", &self.dp_aware)
.finish()
}
}
/// Empty-bodied `501 Not Implemented` response shared by the endpoints the
/// gRPC PD router does not support yet.
fn pd_not_implemented() -> Response {
    StatusCode::NOT_IMPLEMENTED.into_response()
}

/// `RouterTrait` surface of the gRPC PD router.
///
/// Only generate and chat requests are routed; every other endpoint replies
/// with `501 Not Implemented`.
#[async_trait]
impl RouterTrait for GrpcPDRouter {
    /// Exposes the router as `Any` so callers can downcast it.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Placeholder: replies 501 with an explanatory message.
    async fn health_generate(&self, _req: Request<Body>) -> Response {
        // TODO: Implement actual generation test for gRPC PD mode
        let reply = (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC PD",
        );
        reply.into_response()
    }

    /// Not supported yet; replies 501.
    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        pd_not_implemented()
    }

    /// Not supported yet; replies 501.
    async fn get_models(&self, _req: Request<Body>) -> Response {
        pd_not_implemented()
    }

    /// Not supported yet; replies 501.
    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        pd_not_implemented()
    }

    /// Routes a generate request through the PD pipeline.
    async fn route_generate(
        &self,
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_generate_impl(headers, body, model_id).await
    }

    /// Routes a chat completion request through the PD pipeline.
    async fn route_chat(
        &self,
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        model_id: Option<&str>,
    ) -> Response {
        self.route_chat_impl(headers, body, model_id).await
    }

    /// Not supported yet; replies 501.
    async fn route_completion(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &CompletionRequest,
        _model_id: Option<&str>,
    ) -> Response {
        pd_not_implemented()
    }

    /// Not supported yet; replies 501.
    async fn route_responses(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &ResponsesRequest,
        _model_id: Option<&str>,
    ) -> Response {
        pd_not_implemented()
    }

    /// Not supported yet; replies 501.
    async fn get_response(
        &self,
        _headers: Option<&HeaderMap>,
        _response_id: &str,
        _params: &ResponsesGetParams,
    ) -> Response {
        pd_not_implemented()
    }

    /// Not supported yet; replies 501.
    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
        pd_not_implemented()
    }

    /// Not supported yet; replies 501.
    async fn route_embeddings(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &EmbeddingRequest,
        _model_id: Option<&str>,
    ) -> Response {
        pd_not_implemented()
    }

    /// Not supported yet; replies 501.
    async fn route_rerank(
        &self,
        _headers: Option<&HeaderMap>,
        _body: &RerankRequest,
        _model_id: Option<&str>,
    ) -> Response {
        pd_not_implemented()
    }

    /// Static identifier for this router flavour.
    fn router_type(&self) -> &'static str {
        "grpc_pd"
    }
}