// PD (Prefill-Decode) gRPC Router Implementation
|
|
|
|
use crate::config::types::RetryConfig;
|
|
use crate::core::{ConnectionMode, WorkerRegistry, WorkerType};
|
|
use crate::policies::PolicyRegistry;
|
|
use crate::protocols::spec::{
|
|
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
|
ResponsesGetParams, ResponsesRequest,
|
|
};
|
|
use crate::reasoning_parser::ReasoningParserFactory;
|
|
use crate::routers::RouterTrait;
|
|
use crate::server::AppContext;
|
|
use crate::tokenizer::traits::Tokenizer;
|
|
use crate::tool_parser::ToolParserFactory;
|
|
use async_trait::async_trait;
|
|
use axum::{
|
|
body::Body,
|
|
extract::Request,
|
|
http::{HeaderMap, StatusCode},
|
|
response::{IntoResponse, Response},
|
|
};
|
|
use std::sync::Arc;
|
|
|
|
use tracing::debug;
|
|
|
|
/// gRPC PD (Prefill-Decode) router implementation for SGLang
|
|
#[derive(Clone)]
|
|
#[allow(dead_code)] // Fields will be used once implementation is complete
|
|
pub struct GrpcPDRouter {
|
|
worker_registry: Arc<WorkerRegistry>,
|
|
policy_registry: Arc<PolicyRegistry>,
|
|
tokenizer: Arc<dyn Tokenizer>,
|
|
reasoning_parser_factory: ReasoningParserFactory,
|
|
tool_parser_factory: ToolParserFactory,
|
|
dp_aware: bool,
|
|
api_key: Option<String>,
|
|
retry_config: RetryConfig,
|
|
configured_reasoning_parser: Option<String>,
|
|
configured_tool_parser: Option<String>,
|
|
pipeline: super::pipeline::ChatCompletionPipeline,
|
|
shared_components: Arc<super::context::SharedComponents>,
|
|
}
|
|
|
|
impl GrpcPDRouter {
|
|
/// Create a new gRPC PD router
|
|
pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
|
|
// Get registries from context
|
|
let worker_registry = ctx.worker_registry.clone();
|
|
let policy_registry = ctx.policy_registry.clone();
|
|
|
|
// Extract necessary components from context
|
|
let tokenizer = ctx
|
|
.tokenizer
|
|
.as_ref()
|
|
.ok_or_else(|| "gRPC PD router requires tokenizer".to_string())?
|
|
.clone();
|
|
let reasoning_parser_factory = ctx
|
|
.reasoning_parser_factory
|
|
.as_ref()
|
|
.ok_or_else(|| "gRPC PD router requires reasoning parser factory".to_string())?
|
|
.clone();
|
|
let tool_parser_factory = ctx
|
|
.tool_parser_factory
|
|
.as_ref()
|
|
.ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())?
|
|
.clone();
|
|
|
|
// Create shared components for pipeline
|
|
let shared_components = Arc::new(super::context::SharedComponents {
|
|
tokenizer: tokenizer.clone(),
|
|
tool_parser_factory: tool_parser_factory.clone(),
|
|
reasoning_parser_factory: reasoning_parser_factory.clone(),
|
|
});
|
|
|
|
// Create response processor
|
|
let processor = super::processing::ResponseProcessor::new(
|
|
tokenizer.clone(),
|
|
tool_parser_factory.clone(),
|
|
reasoning_parser_factory.clone(),
|
|
ctx.configured_tool_parser.clone(),
|
|
ctx.configured_reasoning_parser.clone(),
|
|
);
|
|
|
|
// Create streaming processor
|
|
let streaming_processor = Arc::new(super::streaming::StreamingProcessor::new(
|
|
tokenizer.clone(),
|
|
tool_parser_factory.clone(),
|
|
reasoning_parser_factory.clone(),
|
|
ctx.configured_tool_parser.clone(),
|
|
ctx.configured_reasoning_parser.clone(),
|
|
));
|
|
|
|
// Create PD pipeline
|
|
let pipeline = super::pipeline::ChatCompletionPipeline::new_pd(
|
|
worker_registry.clone(),
|
|
policy_registry.clone(),
|
|
processor,
|
|
streaming_processor,
|
|
);
|
|
|
|
Ok(GrpcPDRouter {
|
|
worker_registry,
|
|
policy_registry,
|
|
tokenizer,
|
|
reasoning_parser_factory,
|
|
tool_parser_factory,
|
|
dp_aware: ctx.router_config.dp_aware,
|
|
api_key: ctx.router_config.api_key.clone(),
|
|
retry_config: ctx.router_config.effective_retry_config(),
|
|
configured_reasoning_parser: ctx.configured_reasoning_parser.clone(),
|
|
configured_tool_parser: ctx.configured_tool_parser.clone(),
|
|
pipeline,
|
|
shared_components,
|
|
})
|
|
}
|
|
|
|
/// Main route_generate implementation with PD dual dispatch
|
|
async fn route_generate_impl(
|
|
&self,
|
|
headers: Option<&HeaderMap>,
|
|
body: &GenerateRequest,
|
|
model_id: Option<&str>,
|
|
) -> Response {
|
|
debug!(
|
|
"Processing generate request for model: {:?} (PD mode)",
|
|
model_id
|
|
);
|
|
|
|
// Use pipeline for ALL requests (streaming and non-streaming)
|
|
self.pipeline
|
|
.execute_generate(
|
|
Arc::new(body.clone()),
|
|
headers.cloned(),
|
|
model_id.map(|s| s.to_string()),
|
|
self.shared_components.clone(),
|
|
)
|
|
.await
|
|
}
|
|
|
|
/// Main route_chat implementation with PD dual dispatch
|
|
async fn route_chat_impl(
|
|
&self,
|
|
headers: Option<&HeaderMap>,
|
|
body: &ChatCompletionRequest,
|
|
model_id: Option<&str>,
|
|
) -> Response {
|
|
debug!(
|
|
"Processing chat completion request for model: {:?} (PD mode)",
|
|
model_id
|
|
);
|
|
|
|
// Use pipeline for ALL requests (streaming and non-streaming)
|
|
self.pipeline
|
|
.execute_chat(
|
|
Arc::new(body.clone()),
|
|
headers.cloned(),
|
|
model_id.map(|s| s.to_string()),
|
|
self.shared_components.clone(),
|
|
)
|
|
.await
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Debug for GrpcPDRouter {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
let prefill_workers = self.worker_registry.get_workers_filtered(
|
|
None,
|
|
Some(WorkerType::Prefill {
|
|
bootstrap_port: None,
|
|
}),
|
|
Some(ConnectionMode::Grpc { port: None }),
|
|
false,
|
|
);
|
|
let decode_workers = self.worker_registry.get_workers_filtered(
|
|
None,
|
|
Some(WorkerType::Decode),
|
|
Some(ConnectionMode::Grpc { port: None }),
|
|
false,
|
|
);
|
|
f.debug_struct("GrpcPDRouter")
|
|
.field("prefill_workers_count", &prefill_workers.len())
|
|
.field("decode_workers_count", &decode_workers.len())
|
|
.field("dp_aware", &self.dp_aware)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl RouterTrait for GrpcPDRouter {
|
|
fn as_any(&self) -> &dyn std::any::Any {
|
|
self
|
|
}
|
|
|
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
|
// TODO: Implement actual generation test for gRPC PD mode
|
|
(
|
|
StatusCode::NOT_IMPLEMENTED,
|
|
"Health generate not yet implemented for gRPC PD",
|
|
)
|
|
.into_response()
|
|
}
|
|
|
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
}
|
|
|
|
async fn get_models(&self, _req: Request<Body>) -> Response {
|
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
}
|
|
|
|
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
}
|
|
|
|
async fn route_generate(
|
|
&self,
|
|
headers: Option<&HeaderMap>,
|
|
body: &GenerateRequest,
|
|
model_id: Option<&str>,
|
|
) -> Response {
|
|
self.route_generate_impl(headers, body, model_id).await
|
|
}
|
|
|
|
async fn route_chat(
|
|
&self,
|
|
headers: Option<&HeaderMap>,
|
|
body: &ChatCompletionRequest,
|
|
model_id: Option<&str>,
|
|
) -> Response {
|
|
self.route_chat_impl(headers, body, model_id).await
|
|
}
|
|
|
|
async fn route_completion(
|
|
&self,
|
|
_headers: Option<&HeaderMap>,
|
|
_body: &CompletionRequest,
|
|
_model_id: Option<&str>,
|
|
) -> Response {
|
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
}
|
|
|
|
async fn route_responses(
|
|
&self,
|
|
_headers: Option<&HeaderMap>,
|
|
_body: &ResponsesRequest,
|
|
_model_id: Option<&str>,
|
|
) -> Response {
|
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
}
|
|
|
|
async fn get_response(
|
|
&self,
|
|
_headers: Option<&HeaderMap>,
|
|
_response_id: &str,
|
|
_params: &ResponsesGetParams,
|
|
) -> Response {
|
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
}
|
|
|
|
async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
}
|
|
|
|
async fn route_embeddings(
|
|
&self,
|
|
_headers: Option<&HeaderMap>,
|
|
_body: &EmbeddingRequest,
|
|
_model_id: Option<&str>,
|
|
) -> Response {
|
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
}
|
|
|
|
async fn route_rerank(
|
|
&self,
|
|
_headers: Option<&HeaderMap>,
|
|
_body: &RerankRequest,
|
|
_model_id: Option<&str>,
|
|
) -> Response {
|
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
|
}
|
|
|
|
fn router_type(&self) -> &'static str {
|
|
"grpc_pd"
|
|
}
|
|
}
|