diff --git a/sgl-router/src/config/types.rs b/sgl-router/src/config/types.rs
index a45d52bd2..5f5b227af 100644
--- a/sgl-router/src/config/types.rs
+++ b/sgl-router/src/config/types.rs
@@ -101,6 +101,11 @@ pub enum RoutingMode {
         #[serde(skip_serializing_if = "Option::is_none")]
         decode_policy: Option<PolicyConfig>,
     },
+    #[serde(rename = "openai")]
+    OpenAI {
+        /// OpenAI-compatible API base(s), provided via worker URLs
+        worker_urls: Vec<String>,
+    },
 }
 
 impl RoutingMode {
@@ -116,6 +121,8 @@ impl RoutingMode {
                 decode_urls,
                 ..
             } => prefill_urls.len() + decode_urls.len(),
+            // OpenAI mode represents a single upstream
+            RoutingMode::OpenAI { .. } => 1,
         }
     }
 
@@ -380,6 +387,7 @@ impl RouterConfig {
         match self.mode {
             RoutingMode::Regular { .. } => "regular",
             RoutingMode::PrefillDecode { .. } => "prefill_decode",
+            RoutingMode::OpenAI { .. } => "openai",
         }
     }
 
diff --git a/sgl-router/src/config/validation.rs b/sgl-router/src/config/validation.rs
index a0a31fd23..710ad3fc8 100644
--- a/sgl-router/src/config/validation.rs
+++ b/sgl-router/src/config/validation.rs
@@ -95,6 +95,20 @@ impl ConfigValidator {
                     Self::validate_policy(d_policy)?;
                 }
             }
+            RoutingMode::OpenAI { worker_urls } => {
+                // Require exactly one worker URL for OpenAI router
+                if worker_urls.len() != 1 {
+                    return Err(ConfigError::ValidationFailed {
+                        reason: "OpenAI mode requires exactly one --worker-urls entry".to_string(),
+                    });
+                }
+                // Validate URL format
+                if let Err(e) = url::Url::parse(&worker_urls[0]) {
+                    return Err(ConfigError::ValidationFailed {
+                        reason: format!("Invalid OpenAI worker URL '{}': {}", &worker_urls[0], e),
+                    });
+                }
+            }
         }
         Ok(())
     }
@@ -243,6 +257,12 @@ impl ConfigValidator {
                     });
                 }
             }
+            RoutingMode::OpenAI { .. } => {
+                // OpenAI mode doesn't use service discovery
+                return Err(ConfigError::ValidationFailed {
+                    reason: "OpenAI mode does not support service discovery".to_string(),
+                });
+            }
         }
 
         Ok(())
diff --git a/sgl-router/src/main.rs b/sgl-router/src/main.rs
index 60986bbea..8ec24a722 100644
--- a/sgl-router/src/main.rs
+++ b/sgl-router/src/main.rs
@@ -1,4 +1,4 @@
-use clap::{ArgAction, Parser};
+use clap::{ArgAction, Parser, ValueEnum};
 use sglang_router_rs::config::{
     CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
     HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
@@ -41,6 +41,33 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
     prefill_entries
 }
 
+#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
+pub enum Backend {
+    #[value(name = "sglang")]
+    Sglang,
+    #[value(name = "vllm")]
+    Vllm,
+    #[value(name = "trtllm")]
+    Trtllm,
+    #[value(name = "openai")]
+    Openai,
+    #[value(name = "anthropic")]
+    Anthropic,
+}
+
+impl std::fmt::Display for Backend {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let s = match self {
+            Backend::Sglang => "sglang",
+            Backend::Vllm => "vllm",
+            Backend::Trtllm => "trtllm",
+            Backend::Openai => "openai",
+            Backend::Anthropic => "anthropic",
+        };
+        write!(f, "{}", s)
+    }
+}
+
 #[derive(Parser, Debug)]
 #[command(name = "sglang-router")]
 #[command(about = "SGLang Router - High-performance request distribution across worker nodes")]
@@ -145,6 +172,10 @@ struct CliArgs {
     #[arg(long)]
     api_key: Option<String>,
 
+    /// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
+    #[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
+    backend: Backend,
+
     /// Directory to store log files
     #[arg(long)]
     log_dir: Option<String>,
@@ -339,6 +370,11 @@ impl CliArgs {
             RoutingMode::Regular {
                 worker_urls: vec![],
             }
+        } else if matches!(self.backend, Backend::Openai) {
+            // OpenAI backend mode - use worker_urls as base(s)
+            RoutingMode::OpenAI {
+                worker_urls: self.worker_urls.clone(),
+            }
         } else if self.pd_disaggregation {
             let decode_urls = self.decode.clone();
 
@@ -409,8 +445,14 @@ impl CliArgs {
                 }
                 all_urls.extend(decode_urls.clone());
             }
+            RoutingMode::OpenAI { .. } => {
+                // For connection-mode detection, skip URLs; OpenAI forces HTTP below.
+            }
         }
-        let connection_mode = Self::determine_connection_mode(&all_urls);
+        let connection_mode = match &mode {
+            RoutingMode::OpenAI { .. } => ConnectionMode::Http,
+            _ => Self::determine_connection_mode(&all_urls),
+        };
 
         // Build RouterConfig
         Ok(RouterConfig {
@@ -543,16 +585,28 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Print startup info
     println!("SGLang Router starting...");
     println!("Host: {}:{}", cli_args.host, cli_args.port);
-    println!(
-        "Mode: {}",
-        if cli_args.enable_igw {
-            "IGW (Inference Gateway)"
-        } else if cli_args.pd_disaggregation {
-            "PD Disaggregated"
-        } else {
-            "Regular"
+    let mode_str = if cli_args.enable_igw {
+        "IGW (Inference Gateway)".to_string()
+    } else if matches!(cli_args.backend, Backend::Openai) {
+        "OpenAI Backend".to_string()
+    } else if cli_args.pd_disaggregation {
+        "PD Disaggregated".to_string()
+    } else {
+        format!("Regular ({})", cli_args.backend)
+    };
+    println!("Mode: {}", mode_str);
+
+    // Warn for runtimes that are parsed but not yet implemented
+    match cli_args.backend {
+        Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
+            println!(
+                "WARNING: runtime '{}' not implemented yet; falling back to regular routing. \
+Provide --worker-urls or PD flags as usual.",
+                cli_args.backend
+            );
         }
-    );
+        Backend::Sglang | Backend::Openai => {}
+    }
 
     if !cli_args.enable_igw {
         println!("Policy: {}", cli_args.policy);
diff --git a/sgl-router/src/routers/factory.rs b/sgl-router/src/routers/factory.rs
index 05bb459de..d1bdc0fce 100644
--- a/sgl-router/src/routers/factory.rs
+++ b/sgl-router/src/routers/factory.rs
@@ -1,7 +1,7 @@
 //! Factory for creating router instances
 
 use super::{
-    http::{pd_router::PDRouter, router::Router},
+    http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
     RouterTrait,
 };
 use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
@@ -44,6 +44,9 @@ impl RouterFactory {
                     )
                     .await
                 }
+                RoutingMode::OpenAI { .. } => {
+                    Err("OpenAI mode requires HTTP connection_mode".to_string())
+                }
             }
         }
         ConnectionMode::Http => {
@@ -69,6 +72,9 @@ impl RouterFactory {
                     )
                     .await
                 }
+                RoutingMode::OpenAI { worker_urls, .. } => {
+                    Self::create_openai_router(worker_urls.clone(), ctx).await
+                }
             }
         }
     }
@@ -164,6 +170,23 @@ impl RouterFactory {
         Ok(Box::new(router))
     }
 
+    /// Create an OpenAI router
+    async fn create_openai_router(
+        worker_urls: Vec<String>,
+        ctx: &Arc<AppContext>,
+    ) -> Result<Box<dyn RouterTrait>, String> {
+        // Use the first worker URL as the OpenAI-compatible base
+        let base_url = worker_urls
+            .first()
+            .cloned()
+            .ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
+
+        let router =
+            OpenAIRouter::new(base_url, Some(ctx.router_config.circuit_breaker.clone())).await?;
+
+        Ok(Box::new(router))
+    }
+
     /// Create an IGW router (placeholder for future implementation)
     async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
         // For now, return an error indicating IGW is not yet implemented
diff --git a/sgl-router/src/routers/http/mod.rs b/sgl-router/src/routers/http/mod.rs
index 3f31b6f86..9f955b651 100644
--- a/sgl-router/src/routers/http/mod.rs
+++ b/sgl-router/src/routers/http/mod.rs
@@ -1,5 +1,6 @@
 //! HTTP router implementations
 
+pub mod openai_router;
 pub mod pd_router;
 pub mod pd_types;
 pub mod router;
diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs
new file mode 100644
index 000000000..551dd1aa3
--- /dev/null
+++ b/sgl-router/src/routers/http/openai_router.rs
@@ -0,0 +1,379 @@
+//! OpenAI router implementation (reqwest-based)
+
+use crate::config::CircuitBreakerConfig;
+use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
+use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
+use async_trait::async_trait;
+use axum::{
+    body::Body,
+    extract::Request,
+    http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
+    response::{IntoResponse, Response},
+};
+use futures_util::StreamExt;
+use std::{
+    any::Any,
+    sync::atomic::{AtomicBool, Ordering},
+};
+
+/// Router for OpenAI backend
+#[derive(Debug)]
+pub struct OpenAIRouter {
+    /// HTTP client for upstream OpenAI-compatible API
+    client: reqwest::Client,
+    /// Base URL for identification (no trailing slash)
+    base_url: String,
+    /// Circuit breaker
+    circuit_breaker: CircuitBreaker,
+    /// Health status
+    healthy: AtomicBool,
+}
+
+impl OpenAIRouter {
+    /// Create a new OpenAI router
+    pub async fn new(
+        base_url: String,
+        circuit_breaker_config: Option<CircuitBreakerConfig>,
+    ) -> Result<Self, String> {
+        let client = reqwest::Client::builder()
+            .timeout(std::time::Duration::from_secs(300))
+            .build()
+            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
+
+        let base_url = base_url.trim_end_matches('/').to_string();
+
+        // Convert circuit breaker config
+        let core_cb_config = circuit_breaker_config
+            .map(|cb| CoreCircuitBreakerConfig {
+                failure_threshold: cb.failure_threshold,
+                success_threshold: cb.success_threshold,
+                timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
+                window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
+            })
+            .unwrap_or_default();
+
+        let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
+
+        Ok(Self {
+            client,
+            base_url,
+            circuit_breaker,
+            healthy: AtomicBool::new(true),
+        })
+    }
+}
+
+#[async_trait]
+impl super::super::WorkerManagement for OpenAIRouter {
+    async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
+        Err("Cannot add workers to OpenAI router".to_string())
+    }
+
+    fn remove_worker(&self, _worker_url: &str) {
+        // No-op for OpenAI router
+    }
+
+    fn get_worker_urls(&self) -> Vec<String> {
+        vec![self.base_url.clone()]
+    }
+}
+
+#[async_trait]
+impl super::super::RouterTrait for OpenAIRouter {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    async fn health(&self, _req: Request) -> Response {
+        // Simple upstream probe: GET {base}/v1/models without auth
+        let url = format!("{}/v1/models", self.base_url);
+        match self
+            .client
+            .get(&url)
+            .timeout(std::time::Duration::from_secs(2))
+            .send()
+            .await
+        {
+            Ok(resp) => {
+                let code = resp.status();
+                // Treat success and auth-required as healthy (endpoint reachable)
+                if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
+                    (StatusCode::OK, "OK").into_response()
+                } else {
+                    (
+                        StatusCode::SERVICE_UNAVAILABLE,
+                        format!("Upstream status: {}", code),
+                    )
+                        .into_response()
+                }
+            }
+            Err(e) => (
+                StatusCode::SERVICE_UNAVAILABLE,
+                format!("Upstream error: {}", e),
+            )
+                .into_response(),
+        }
+    }
+
+    async fn health_generate(&self, _req: Request) -> Response {
+        // For OpenAI, health_generate is the same as health
+        self.health(_req).await
+    }
+
+    async fn get_server_info(&self, _req: Request) -> Response {
+        let info = serde_json::json!({
+            "router_type": "openai",
+            "workers": 1,
+            "base_url": &self.base_url
+        });
+        (StatusCode::OK, info.to_string()).into_response()
+    }
+
+    async fn get_models(&self, req: Request) -> Response {
+        // Proxy to upstream /v1/models; forward Authorization header if provided
+        let headers = req.headers();
+
+        let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
+
+        if let Some(auth) = headers
+            .get("authorization")
+            .or_else(|| headers.get("Authorization"))
+        {
+            upstream = upstream.header("Authorization", auth);
+        }
+
+        match upstream.send().await {
+            Ok(res) => {
+                let status = StatusCode::from_u16(res.status().as_u16())
+                    .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
+                let content_type = res.headers().get(CONTENT_TYPE).cloned();
+                match res.bytes().await {
+                    Ok(body) => {
+                        let mut response = Response::new(axum::body::Body::from(body));
+                        *response.status_mut() = status;
+                        if let Some(ct) = content_type {
+                            response.headers_mut().insert(CONTENT_TYPE, ct);
+                        }
+                        response
+                    }
+                    Err(e) => (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        format!("Failed to read upstream response: {}", e),
+                    )
+                        .into_response(),
+                }
+            }
+            Err(e) => (
+                StatusCode::BAD_GATEWAY,
+                format!("Failed to contact upstream: {}", e),
+            )
+                .into_response(),
+        }
+    }
+
+    async fn get_model_info(&self, _req: Request) -> Response {
+        // Not directly supported without model param; return 501
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "get_model_info not implemented for OpenAI router",
+        )
+            .into_response()
+    }
+
+    async fn route_generate(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _body: &GenerateRequest,
+    ) -> Response {
+        // Generate endpoint is SGLang-specific, not supported for OpenAI backend
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "Generate endpoint not supported for OpenAI backend",
+        )
+            .into_response()
+    }
+
+    async fn route_chat(
+        &self,
+        headers: Option<&HeaderMap>,
+        body: &ChatCompletionRequest,
+    ) -> Response {
+        if !self.circuit_breaker.can_execute() {
+            return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
+        }
+
+        // Serialize request body, removing SGLang-only fields
+        let mut payload = match serde_json::to_value(body) {
+            Ok(v) => v,
+            Err(e) => {
+                return (
+                    StatusCode::BAD_REQUEST,
+                    format!("Failed to serialize request: {}", e),
+                )
+                    .into_response();
+            }
+        };
+        if let Some(obj) = payload.as_object_mut() {
+            for key in [
+                "top_k",
+                "min_p",
+                "min_tokens",
+                "regex",
+                "ebnf",
+                "stop_token_ids",
"no_stop_trim", + "ignore_eos", + "continue_final_message", + "skip_special_tokens", + "lora_path", + "session_params", + "separate_reasoning", + "stream_reasoning", + "chat_template_kwargs", + "return_hidden_states", + "repetition_penalty", + ] { + obj.remove(key); + } + } + + let url = format!("{}/v1/chat/completions", self.base_url); + let mut req = self.client.post(&url).json(&payload); + + // Forward Authorization header if provided + if let Some(h) = headers { + if let Some(auth) = h.get("authorization").or_else(|| h.get("Authorization")) { + req = req.header("Authorization", auth); + } + } + + // Accept SSE when stream=true + if body.stream { + req = req.header("Accept", "text/event-stream"); + } + + let resp = match req.send().await { + Ok(r) => r, + Err(e) => { + self.circuit_breaker.record_failure(); + return ( + StatusCode::SERVICE_UNAVAILABLE, + format!("Failed to contact upstream: {}", e), + ) + .into_response(); + } + }; + + let status = StatusCode::from_u16(resp.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + if !body.stream { + // Capture Content-Type before consuming response body + let content_type = resp.headers().get(CONTENT_TYPE).cloned(); + match resp.bytes().await { + Ok(body) => { + self.circuit_breaker.record_success(); + let mut response = Response::new(axum::body::Body::from(body)); + *response.status_mut() = status; + if let Some(ct) = content_type { + response.headers_mut().insert(CONTENT_TYPE, ct); + } + response + } + Err(e) => { + self.circuit_breaker.record_failure(); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response: {}", e), + ) + .into_response() + } + } + } else { + // Stream SSE bytes to client + let stream = resp.bytes_stream(); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + tokio::spawn(async move { + let mut s = stream; + while let Some(chunk) = s.next().await { + match chunk { + Ok(bytes) => { + if tx.send(Ok(bytes)).is_err() { + break; + } + } + Err(e) => { + let _ = tx.send(Err(format!("Stream error: {}", e))); + break; + } + } + } + }); + let mut response = Response::new(Body::from_stream( + tokio_stream::wrappers::UnboundedReceiverStream::new(rx), + )); + *response.status_mut() = status; + response + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + response + } + } + + async fn route_completion( + &self, + _headers: Option<&HeaderMap>, + _body: &CompletionRequest, + ) -> Response { + // Completion endpoint not implemented for OpenAI backend + ( + StatusCode::NOT_IMPLEMENTED, + "Completion endpoint not implemented for OpenAI backend", + ) + .into_response() + } + + async fn flush_cache(&self) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "flush_cache not supported for OpenAI router", + ) + .into_response() + } + + async fn get_worker_loads(&self) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "get_worker_loads not supported for OpenAI router", + ) + .into_response() + } + + fn router_type(&self) -> &'static str { + "openai" + } + + fn readiness(&self) -> Response { + if self.healthy.load(Ordering::Acquire) && self.circuit_breaker.can_execute() { + (StatusCode::OK, "Ready").into_response() + } else { + (StatusCode::SERVICE_UNAVAILABLE, "Not ready").into_response() + } + } + + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Embeddings endpoint not implemented for OpenAI backend", + ) + .into_response() + } + + async fn route_rerank(&self, _headers: 
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "Rerank endpoint not implemented for OpenAI backend",
+        )
+            .into_response()
+    }
+}
diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs
index 76ef98821..e610fedb3 100644
--- a/sgl-router/src/routers/mod.rs
+++ b/sgl-router/src/routers/mod.rs
@@ -17,6 +17,8 @@ pub mod header_utils;
 pub mod http;
 
 pub use factory::RouterFactory;
+// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
+pub use http::{openai_router, pd_router, pd_types, router};
 
 /// Worker management trait for administrative operations
 ///
diff --git a/sgl-router/tests/common/mock_mcp_server.rs b/sgl-router/tests/common/mock_mcp_server.rs
index 6a2dd498d..daeec8001 100644
--- a/sgl-router/tests/common/mock_mcp_server.rs
+++ b/sgl-router/tests/common/mock_mcp_server.rs
@@ -63,10 +63,7 @@ impl ServerHandler for MockSearchServer {
         ServerInfo {
             protocol_version: ProtocolVersion::V_2024_11_05,
             capabilities: ServerCapabilities::builder().enable_tools().build(),
-            server_info: Implementation {
-                name: "Mock MCP Server".to_string(),
-                version: "1.0.0".to_string(),
-            },
+            server_info: Implementation::from_build_env(),
             instructions: Some("Mock server for testing".to_string()),
         }
     }
diff --git a/sgl-router/tests/common/mock_openai_server.rs b/sgl-router/tests/common/mock_openai_server.rs
new file mode 100644
index 000000000..643fd5e98
--- /dev/null
+++ b/sgl-router/tests/common/mock_openai_server.rs
@@ -0,0 +1,238 @@
+//! Mock OpenAI-compatible API server for testing
+
+#![allow(dead_code)]
+
+use axum::{
+    body::Body,
+    extract::{Request, State},
+    http::{HeaderValue, StatusCode},
+    response::sse::{Event, KeepAlive},
+    response::{IntoResponse, Response, Sse},
+    routing::post,
+    Json, Router,
+};
+use futures_util::stream::{self, StreamExt};
+use serde_json::json;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use tokio::net::TcpListener;
+
+/// Mock OpenAI API server for testing
+pub struct MockOpenAIServer {
+    addr: SocketAddr,
+    _handle: tokio::task::JoinHandle<()>,
+}
+
+#[derive(Clone)]
+struct MockServerState {
+    require_auth: bool,
+    expected_auth: Option<String>,
+}
+
+impl MockOpenAIServer {
+    /// Create and start a new mock OpenAI server
+    pub async fn new() -> Self {
+        Self::new_with_auth(None).await
+    }
+
+    /// Create and start a new mock OpenAI server with optional auth requirement
+    pub async fn new_with_auth(expected_auth: Option<String>) -> Self {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        let state = Arc::new(MockServerState {
+            require_auth: expected_auth.is_some(),
+            expected_auth,
+        });
+
+        let app = Router::new()
+            .route("/v1/chat/completions", post(mock_chat_completions))
+            .route("/v1/completions", post(mock_completions))
+            .route("/v1/models", post(mock_models).get(mock_models))
+            .with_state(state);
+
+        let handle = tokio::spawn(async move {
+            axum::serve(listener, app).await.unwrap();
+        });
+
+        // Give the server a moment to start
+        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
+
+        Self {
+            addr,
+            _handle: handle,
+        }
+    }
+
+    /// Get the base URL for this mock server
+    pub fn base_url(&self) -> String {
+        format!("http://{}", self.addr)
+    }
+}
+
+/// Mock chat completions endpoint
+async fn mock_chat_completions(req: Request) -> Response {
+    let (_, body) = req.into_parts();
+    let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
+        Ok(bytes) => bytes,
+        Err(_) => return StatusCode::BAD_REQUEST.into_response(),
+    };
+
+    let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
+        Ok(req) => req,
+        Err(_) => return StatusCode::BAD_REQUEST.into_response(),
+    };
+
+    // Extract model from request or use default (owned String to satisfy 'static in stream)
+    let model: String = request
+        .get("model")
+        .and_then(|v| v.as_str())
+        .unwrap_or("gpt-3.5-turbo")
+        .to_string();
+
+    // If stream requested, return SSE
+    let is_stream = request
+        .get("stream")
+        .and_then(|v| v.as_bool())
+        .unwrap_or(false);
+
+    if is_stream {
+        let created = 1677652288u64;
+        // Single chunk then [DONE]
+        let model_chunk = model.clone();
+        let event_stream = stream::once(async move {
+            let chunk = json!({
+                "id": "chatcmpl-123456789",
+                "object": "chat.completion.chunk",
+                "created": created,
+                "model": model_chunk,
+                "choices": [{
+                    "index": 0,
+                    "delta": {
+                        "content": "Hello!"
+                    },
+                    "finish_reason": null
+                }]
+            });
+            Ok::<_, std::convert::Infallible>(Event::default().data(chunk.to_string()))
+        })
+        .chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
+
+        Sse::new(event_stream)
+            .keep_alive(KeepAlive::default())
+            .into_response()
+    } else {
+        // Create a mock non-streaming response
+        let response = json!({
+            "id": "chatcmpl-123456789",
+            "object": "chat.completion",
+            "created": 1677652288,
+            "model": model,
+            "choices": [{
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": "Hello! I'm a mock OpenAI assistant. How can I help you today?"
+                },
+                "finish_reason": "stop"
+            }],
+            "usage": {
+                "prompt_tokens": 9,
+                "completion_tokens": 12,
+                "total_tokens": 21
+            }
+        });
+
+        Json(response).into_response()
+    }
+}
+
+/// Mock completions endpoint (legacy)
+async fn mock_completions(req: Request) -> Response {
+    let (_, body) = req.into_parts();
+    let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
+        Ok(bytes) => bytes,
+        Err(_) => return StatusCode::BAD_REQUEST.into_response(),
+    };
+
+    let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
+        Ok(req) => req,
+        Err(_) => return StatusCode::BAD_REQUEST.into_response(),
+    };
+
+    let model = request["model"].as_str().unwrap_or("text-davinci-003");
+
+    let response = json!({
+        "id": "cmpl-123456789",
+        "object": "text_completion",
+        "created": 1677652288,
+        "model": model,
+        "choices": [{
+            "text": " This is a mock completion response.",
+            "index": 0,
+            "logprobs": null,
+            "finish_reason": "stop"
+        }],
+        "usage": {
+            "prompt_tokens": 5,
+            "completion_tokens": 7,
+            "total_tokens": 12
+        }
+    });
+
+    Json(response).into_response()
+}
+
+/// Mock models endpoint
+async fn mock_models(State(state): State<Arc<MockServerState>>, req: Request) -> Response {
+    // Optionally enforce Authorization header
+    if state.require_auth {
+        let auth = req
+            .headers()
+            .get("authorization")
+            .or_else(|| req.headers().get("Authorization"))
+            .and_then(|v| v.to_str().ok())
+            .map(|s| s.to_string());
+        let auth_ok = match (&state.expected_auth, auth) {
+            (Some(expected), Some(got)) => &got == expected,
+            (None, Some(_)) => true,
+            _ => false,
+        };
+        if !auth_ok {
+            let mut response = Response::new(Body::from(
+                json!({
+                    "error": {
+                        "message": "Unauthorized",
+                        "type": "invalid_request_error"
+                    }
+                })
+                .to_string(),
+            ));
+            *response.status_mut() = StatusCode::UNAUTHORIZED;
+            response
+                .headers_mut()
+                .insert("WWW-Authenticate", HeaderValue::from_static("Bearer"));
+            return response;
+        }
+    }
+
+    let response = json!({
+        "object": "list",
+        "data": [
+            {
+                "id": "gpt-4",
+                "object": "model",
"created": 1677610602, + "owned_by": "openai" + }, + { + "id": "gpt-3.5-turbo", + "object": "model", + "created": 1677610602, + "owned_by": "openai" + } + ] + }); + + Json(response).into_response() +} diff --git a/sgl-router/tests/common/mock_worker.rs b/sgl-router/tests/common/mock_worker.rs old mode 100644 new mode 100755 diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs index 553371fbc..0170cd765 100644 --- a/sgl-router/tests/common/mod.rs +++ b/sgl-router/tests/common/mod.rs @@ -2,6 +2,7 @@ #![allow(dead_code)] pub mod mock_mcp_server; +pub mod mock_openai_server; pub mod mock_worker; pub mod test_app; diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs new file mode 100644 index 000000000..ec38a6dd5 --- /dev/null +++ b/sgl-router/tests/test_openai_routing.rs @@ -0,0 +1,419 @@ +//! Comprehensive integration tests for OpenAI backend functionality + +use axum::{ + body::Body, + extract::Request, + http::{Method, StatusCode}, + routing::post, + Router, +}; +use serde_json::json; +use sglang_router_rs::{ + config::{RouterConfig, RoutingMode}, + protocols::spec::{ + ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, UserMessageContent, + }, + routers::{openai_router::OpenAIRouter, RouterTrait}, +}; +use std::sync::Arc; +use tower::ServiceExt; + +mod common; +use common::mock_openai_server::MockOpenAIServer; + +/// Helper function to create a minimal chat completion request for testing +fn create_minimal_chat_request() -> ChatCompletionRequest { + let val = json!({ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "Hello"} + ], + "max_tokens": 100 + }); + serde_json::from_value(val).unwrap() +} + +/// Helper function to create a minimal completion request for testing +fn create_minimal_completion_request() -> CompletionRequest { + CompletionRequest { + model: "gpt-3.5-turbo".to_string(), + prompt: sglang_router_rs::protocols::spec::StringOrArray::String("Hello".to_string()), + suffix: None, + max_tokens: Some(100), + temperature: None, + top_p: None, + n: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + top_k: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + json_schema: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + skip_special_tokens: true, + lora_path: None, + session_params: None, + return_hidden_states: false, + other: serde_json::Map::new(), + } +} + +// ============= Basic Unit Tests ============= + +/// Test basic OpenAI router creation and configuration +#[tokio::test] +async fn test_openai_router_creation() { + let router = OpenAIRouter::new("https://api.openai.com".to_string(), None).await; + + assert!(router.is_ok(), "Router creation should succeed"); + + let router = router.unwrap(); + assert_eq!(router.router_type(), "openai"); + assert!(!router.is_pd_mode()); +} + +/// Test health endpoints +#[tokio::test] +async fn test_openai_router_health() { + let router = OpenAIRouter::new("https://api.openai.com".to_string(), None) + .await + .unwrap(); + + let req = Request::builder() + .method(Method::GET) + .uri("/health") + .body(Body::empty()) + .unwrap(); + + let response = router.health(req).await; + assert_eq!(response.status(), StatusCode::OK); +} + +/// Test server info endpoint +#[tokio::test] +async fn 
+async fn test_openai_router_server_info() {
+    let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
+        .await
+        .unwrap();
+
+    let req = Request::builder()
+        .method(Method::GET)
+        .uri("/info")
+        .body(Body::empty())
+        .unwrap();
+
+    let response = router.get_server_info(req).await;
+    assert_eq!(response.status(), StatusCode::OK);
+
+    let (_, body) = response.into_parts();
+    let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+    let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
+
+    assert!(body_str.contains("openai"));
+}
+
+/// Test models endpoint
+#[tokio::test]
+async fn test_openai_router_models() {
+    // Use mock server for deterministic models response
+    let mock_server = MockOpenAIServer::new().await;
+    let router = OpenAIRouter::new(mock_server.base_url(), None)
+        .await
+        .unwrap();
+
+    let req = Request::builder()
+        .method(Method::GET)
+        .uri("/models")
+        .body(Body::empty())
+        .unwrap();
+
+    let response = router.get_models(req).await;
+    assert_eq!(response.status(), StatusCode::OK);
+
+    let (_, body) = response.into_parts();
+    let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+    let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
+    let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
+
+    assert_eq!(models["object"], "list");
+    assert!(models["data"].is_array());
+}
+
+/// Test router factory with OpenAI routing mode
+#[tokio::test]
+async fn test_router_factory_openai_mode() {
+    let routing_mode = RoutingMode::OpenAI {
+        worker_urls: vec!["https://api.openai.com".to_string()],
+    };
+
+    let router_config =
+        RouterConfig::new(routing_mode, sglang_router_rs::config::PolicyConfig::Random);
+
+    let app_context = common::create_test_context(router_config);
+
+    let router = sglang_router_rs::routers::RouterFactory::create_router(&app_context).await;
+    assert!(
+        router.is_ok(),
+        "Router factory should create OpenAI router successfully"
+    );
+
+    let router = router.unwrap();
+    assert_eq!(router.router_type(), "openai");
+}
+
+/// Test that unsupported endpoints return proper error codes
+#[tokio::test]
+async fn test_unsupported_endpoints() {
+    let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
+        .await
+        .unwrap();
+
+    // Test generate endpoint (SGLang-specific, should not be supported)
+    let generate_request = GenerateRequest {
+        prompt: None,
+        text: Some("Hello world".to_string()),
+        input_ids: None,
+        parameters: None,
+        sampling_params: None,
+        stream: false,
+        return_logprob: false,
+        lora_path: None,
+        session_params: None,
+        return_hidden_states: false,
+        rid: None,
+    };
+
+    let response = router.route_generate(None, &generate_request).await;
+    assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
+
+    // Test completion endpoint (should also not be supported)
+    let completion_request = create_minimal_completion_request();
+    let response = router.route_completion(None, &completion_request).await;
+    assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
+}
+
+// ============= Mock Server E2E Tests =============
+
+/// Test chat completion with mock OpenAI server
+#[tokio::test]
+async fn test_openai_router_chat_completion_with_mock() {
+    // Start a mock OpenAI server
+    let mock_server = MockOpenAIServer::new().await;
+    let base_url = mock_server.base_url();
+
+    // Create router pointing to mock server
+    let router = OpenAIRouter::new(base_url, None).await.unwrap();
+
+    // Create a minimal chat completion request
+    let mut chat_request = create_minimal_chat_request();
+    chat_request.messages = vec![ChatMessage::User {
+        role: "user".to_string(),
+        content: UserMessageContent::Text("Hello, how are you?".to_string()),
+        name: None,
+    }];
+    chat_request.temperature = Some(0.7);
+
+    // Route the request
+    let response = router.route_chat(None, &chat_request).await;
+
+    // Should get a successful response from mock server
+    assert_eq!(response.status(), StatusCode::OK);
+
+    let (_, body) = response.into_parts();
+    let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+    let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
+    let chat_response: serde_json::Value = serde_json::from_str(&body_str).unwrap();
+
+    // Verify it's a valid chat completion response
+    assert_eq!(chat_response["object"], "chat.completion");
+    assert_eq!(chat_response["model"], "gpt-3.5-turbo");
+    assert!(!chat_response["choices"].as_array().unwrap().is_empty());
+}
+
+/// Test full E2E flow with Axum server
+#[tokio::test]
+async fn test_openai_e2e_with_server() {
+    // Start mock OpenAI server
+    let mock_server = MockOpenAIServer::new().await;
+    let base_url = mock_server.base_url();
+
+    // Create router
+    let router = OpenAIRouter::new(base_url, None).await.unwrap();
+
+    // Create Axum app with chat completions endpoint
+    let app = Router::new().route(
+        "/v1/chat/completions",
+        post({
+            let router = Arc::new(router);
+            move |req: Request| {
+                let router = router.clone();
+                async move {
+                    let (parts, body) = req.into_parts();
+                    let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+                    let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
+
+                    let chat_request: ChatCompletionRequest =
+                        serde_json::from_str(&body_str).unwrap();
+
+                    router.route_chat(Some(&parts.headers), &chat_request).await
+                }
+            }
+        }),
+    );
+
+    // Make a request to the server
+    let request = Request::builder()
+        .method(Method::POST)
+        .uri("/v1/chat/completions")
+        .header("content-type", "application/json")
+        .body(Body::from(
+            json!({
+                "model": "gpt-3.5-turbo",
+                "messages": [
+                    {
+                        "role": "user",
+                        "content": "Hello, world!"
+                    }
+                ],
+                "max_tokens": 100
+            })
+            .to_string(),
+        ))
+        .unwrap();
+
+    let response = app.oneshot(request).await.unwrap();
+    assert_eq!(response.status(), StatusCode::OK);
+
+    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
+        .await
+        .unwrap();
+    let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
+
+    // Verify the response structure
+    assert_eq!(response_json["object"], "chat.completion");
+    assert_eq!(response_json["model"], "gpt-3.5-turbo");
+    assert!(!response_json["choices"].as_array().unwrap().is_empty());
+}
+
+/// Test streaming chat completions pass-through with mock server
+#[tokio::test]
+async fn test_openai_router_chat_streaming_with_mock() {
+    let mock_server = MockOpenAIServer::new().await;
+    let base_url = mock_server.base_url();
+    let router = OpenAIRouter::new(base_url, None).await.unwrap();
+
+    // Build a streaming chat request
+    let val = json!({
+        "model": "gpt-3.5-turbo",
+        "messages": [
+            {"role": "user", "content": "Hello"}
+        ],
+        "max_tokens": 10,
+        "stream": true
+    });
+    let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();
+
+    let response = router.route_chat(None, &chat_request).await;
+    assert_eq!(response.status(), StatusCode::OK);
+
+    // Should be SSE
+    let headers = response.headers();
+    let ct = headers
+        .get("content-type")
+        .unwrap()
+        .to_str()
+        .unwrap()
+        .to_ascii_lowercase();
+    assert!(ct.contains("text/event-stream"));
+
+    // Read entire stream body and assert chunks + DONE
+    let body = axum::body::to_bytes(response.into_body(), usize::MAX)
+        .await
+        .unwrap();
+    let text = String::from_utf8(body.to_vec()).unwrap();
+    assert!(text.contains("chat.completion.chunk"));
+    assert!(text.contains("[DONE]"));
+}
+
+/// Test circuit breaker functionality
+#[tokio::test]
+async fn test_openai_router_circuit_breaker() {
+    // Create router with circuit breaker config
+    let cb_config = sglang_router_rs::config::CircuitBreakerConfig {
+        failure_threshold: 2,
+        success_threshold: 1,
+        timeout_duration_secs: 1,
+        window_duration_secs: 10,
+    };
+
+    let router = OpenAIRouter::new(
+        "http://invalid-url-that-will-fail".to_string(),
+        Some(cb_config),
+    )
+    .await
+    .unwrap();
+
+    let chat_request = create_minimal_chat_request();
+
+    // First few requests should fail and record failures
+    for _ in 0..3 {
+        let response = router.route_chat(None, &chat_request).await;
+        // Should get either an error or circuit breaker response
+        assert!(
+            response.status() == StatusCode::INTERNAL_SERVER_ERROR
+                || response.status() == StatusCode::SERVICE_UNAVAILABLE
+        );
+    }
+}
+
+/// Test that Authorization header is forwarded in /v1/models
+#[tokio::test]
+async fn test_openai_router_models_auth_forwarding() {
+    // Start a mock server that requires Authorization
+    let expected_auth = "Bearer test-token".to_string();
+    let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
+    let router = OpenAIRouter::new(mock_server.base_url(), None)
+        .await
+        .unwrap();
+
+    // 1) Without auth header -> expect 401
+    let req = Request::builder()
+        .method(Method::GET)
+        .uri("/models")
+        .body(Body::empty())
+        .unwrap();
+
+    let response = router.get_models(req).await;
+    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
+
+    // 2) With auth header -> expect 200
+    let req = Request::builder()
+        .method(Method::GET)
+        .uri("/models")
+        .header("Authorization", expected_auth)
+        .body(Body::empty())
+        .unwrap();
+
+    let response = router.get_models(req).await;
+    assert_eq!(response.status(), StatusCode::OK);
+
+    let (_, body) = response.into_parts();
+    let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
+    let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
+    let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
+    assert_eq!(models["object"], "list");
+}