[router] Implement HTTP Dependency Injection Pattern for Router System (#8714)
This commit is contained in:
@@ -22,29 +22,34 @@ use tokio::spawn;
|
||||
use tracing::{error, info, warn, Level};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub router: Arc<dyn RouterTrait>,
|
||||
pub struct AppContext {
|
||||
pub client: Client,
|
||||
pub _concurrency_limiter: Arc<tokio::sync::Semaphore>,
|
||||
pub router_config: RouterConfig,
|
||||
pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
|
||||
// Future dependencies can be added here
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
impl AppContext {
|
||||
pub fn new(
|
||||
router_config: RouterConfig,
|
||||
client: Client,
|
||||
max_concurrent_requests: usize,
|
||||
) -> Result<Self, String> {
|
||||
let router = RouterFactory::create_router(&router_config)?;
|
||||
let router = Arc::from(router);
|
||||
) -> Self {
|
||||
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
|
||||
Ok(Self {
|
||||
router,
|
||||
Self {
|
||||
client,
|
||||
_concurrency_limiter: concurrency_limiter,
|
||||
})
|
||||
router_config,
|
||||
concurrency_limiter,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub router: Arc<dyn RouterTrait>,
|
||||
pub context: Arc<AppContext>,
|
||||
}
|
||||
|
||||
// Fallback handler for unmatched routes
|
||||
async fn sink_handler() -> Response {
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
@@ -60,23 +65,23 @@ async fn readiness(State(state): State<Arc<AppState>>) -> Response {
|
||||
}
|
||||
|
||||
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.health(&state.client, req).await
|
||||
state.router.health(req).await
|
||||
}
|
||||
|
||||
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.health_generate(&state.client, req).await
|
||||
state.router.health_generate(req).await
|
||||
}
|
||||
|
||||
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.get_server_info(&state.client, req).await
|
||||
state.router.get_server_info(req).await
|
||||
}
|
||||
|
||||
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.get_models(&state.client, req).await
|
||||
state.router.get_models(req).await
|
||||
}
|
||||
|
||||
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||
state.router.get_model_info(&state.client, req).await
|
||||
state.router.get_model_info(req).await
|
||||
}
|
||||
|
||||
// Generation endpoints
|
||||
@@ -86,10 +91,7 @@ async fn generate(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<GenerateRequest>,
|
||||
) -> Response {
|
||||
state
|
||||
.router
|
||||
.route_generate(&state.client, Some(&headers), &body)
|
||||
.await
|
||||
state.router.route_generate(Some(&headers), &body).await
|
||||
}
|
||||
|
||||
async fn v1_chat_completions(
|
||||
@@ -97,10 +99,7 @@ async fn v1_chat_completions(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<ChatCompletionRequest>,
|
||||
) -> Response {
|
||||
state
|
||||
.router
|
||||
.route_chat(&state.client, Some(&headers), &body)
|
||||
.await
|
||||
state.router.route_chat(Some(&headers), &body).await
|
||||
}
|
||||
|
||||
async fn v1_completions(
|
||||
@@ -108,10 +107,7 @@ async fn v1_completions(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<CompletionRequest>,
|
||||
) -> Response {
|
||||
state
|
||||
.router
|
||||
.route_completion(&state.client, Some(&headers), &body)
|
||||
.await
|
||||
state.router.route_completion(Some(&headers), &body).await
|
||||
}
|
||||
|
||||
// Worker management endpoints
|
||||
@@ -159,11 +155,11 @@ async fn remove_worker(
|
||||
}
|
||||
|
||||
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||
state.router.flush_cache(&state.client).await
|
||||
state.router.flush_cache().await
|
||||
}
|
||||
|
||||
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||
state.router.get_worker_loads(&state.client).await
|
||||
state.router.get_worker_loads().await
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
@@ -281,11 +277,21 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
let app_state = Arc::new(AppState::new(
|
||||
// Create the application context with all dependencies
|
||||
let app_context = Arc::new(AppContext::new(
|
||||
config.router_config.clone(),
|
||||
client.clone(),
|
||||
config.router_config.max_concurrent_requests,
|
||||
)?);
|
||||
));
|
||||
|
||||
// Create router with the context
|
||||
let router = RouterFactory::create_router(&app_context)?;
|
||||
|
||||
// Create app state with router and context
|
||||
let app_state = Arc::new(AppState {
|
||||
router: Arc::from(router),
|
||||
context: app_context.clone(),
|
||||
});
|
||||
let router_arc = Arc::clone(&app_state.router);
|
||||
|
||||
// Start the service discovery if enabled
|
||||
|
||||
Reference in New Issue
Block a user