[router] Add OpenAI backend support - core function (#10254)
This commit is contained in:
@@ -101,6 +101,11 @@ pub enum RoutingMode {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
decode_policy: Option<PolicyConfig>,
|
decode_policy: Option<PolicyConfig>,
|
||||||
},
|
},
|
||||||
|
#[serde(rename = "openai")]
|
||||||
|
OpenAI {
|
||||||
|
/// OpenAI-compatible API base(s), provided via worker URLs
|
||||||
|
worker_urls: Vec<String>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RoutingMode {
|
impl RoutingMode {
|
||||||
@@ -116,6 +121,8 @@ impl RoutingMode {
|
|||||||
decode_urls,
|
decode_urls,
|
||||||
..
|
..
|
||||||
} => prefill_urls.len() + decode_urls.len(),
|
} => prefill_urls.len() + decode_urls.len(),
|
||||||
|
// OpenAI mode represents a single upstream
|
||||||
|
RoutingMode::OpenAI { .. } => 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,6 +387,7 @@ impl RouterConfig {
|
|||||||
match self.mode {
|
match self.mode {
|
||||||
RoutingMode::Regular { .. } => "regular",
|
RoutingMode::Regular { .. } => "regular",
|
||||||
RoutingMode::PrefillDecode { .. } => "prefill_decode",
|
RoutingMode::PrefillDecode { .. } => "prefill_decode",
|
||||||
|
RoutingMode::OpenAI { .. } => "openai",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -95,6 +95,20 @@ impl ConfigValidator {
|
|||||||
Self::validate_policy(d_policy)?;
|
Self::validate_policy(d_policy)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
RoutingMode::OpenAI { worker_urls } => {
|
||||||
|
// Require exactly one worker URL for OpenAI router
|
||||||
|
if worker_urls.len() != 1 {
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: "OpenAI mode requires exactly one --worker-urls entry".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// Validate URL format
|
||||||
|
if let Err(e) = url::Url::parse(&worker_urls[0]) {
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: format!("Invalid OpenAI worker URL '{}': {}", &worker_urls[0], e),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -243,6 +257,12 @@ impl ConfigValidator {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
RoutingMode::OpenAI { .. } => {
|
||||||
|
// OpenAI mode doesn't use service discovery
|
||||||
|
return Err(ConfigError::ValidationFailed {
|
||||||
|
reason: "OpenAI mode does not support service discovery".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use clap::{ArgAction, Parser};
|
use clap::{ArgAction, Parser, ValueEnum};
|
||||||
use sglang_router_rs::config::{
|
use sglang_router_rs::config::{
|
||||||
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
||||||
HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
@@ -41,6 +41,33 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
|||||||
prefill_entries
|
prefill_entries
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
|
||||||
|
pub enum Backend {
|
||||||
|
#[value(name = "sglang")]
|
||||||
|
Sglang,
|
||||||
|
#[value(name = "vllm")]
|
||||||
|
Vllm,
|
||||||
|
#[value(name = "trtllm")]
|
||||||
|
Trtllm,
|
||||||
|
#[value(name = "openai")]
|
||||||
|
Openai,
|
||||||
|
#[value(name = "anthropic")]
|
||||||
|
Anthropic,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for Backend {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let s = match self {
|
||||||
|
Backend::Sglang => "sglang",
|
||||||
|
Backend::Vllm => "vllm",
|
||||||
|
Backend::Trtllm => "trtllm",
|
||||||
|
Backend::Openai => "openai",
|
||||||
|
Backend::Anthropic => "anthropic",
|
||||||
|
};
|
||||||
|
write!(f, "{}", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(name = "sglang-router")]
|
#[command(name = "sglang-router")]
|
||||||
#[command(about = "SGLang Router - High-performance request distribution across worker nodes")]
|
#[command(about = "SGLang Router - High-performance request distribution across worker nodes")]
|
||||||
@@ -145,6 +172,10 @@ struct CliArgs {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
|
|
||||||
|
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
|
||||||
|
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
|
||||||
|
backend: Backend,
|
||||||
|
|
||||||
/// Directory to store log files
|
/// Directory to store log files
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
log_dir: Option<String>,
|
log_dir: Option<String>,
|
||||||
@@ -339,6 +370,11 @@ impl CliArgs {
|
|||||||
RoutingMode::Regular {
|
RoutingMode::Regular {
|
||||||
worker_urls: vec![],
|
worker_urls: vec![],
|
||||||
}
|
}
|
||||||
|
} else if matches!(self.backend, Backend::Openai) {
|
||||||
|
// OpenAI backend mode - use worker_urls as base(s)
|
||||||
|
RoutingMode::OpenAI {
|
||||||
|
worker_urls: self.worker_urls.clone(),
|
||||||
|
}
|
||||||
} else if self.pd_disaggregation {
|
} else if self.pd_disaggregation {
|
||||||
let decode_urls = self.decode.clone();
|
let decode_urls = self.decode.clone();
|
||||||
|
|
||||||
@@ -409,8 +445,14 @@ impl CliArgs {
|
|||||||
}
|
}
|
||||||
all_urls.extend(decode_urls.clone());
|
all_urls.extend(decode_urls.clone());
|
||||||
}
|
}
|
||||||
|
RoutingMode::OpenAI { .. } => {
|
||||||
|
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let connection_mode = Self::determine_connection_mode(&all_urls);
|
let connection_mode = match &mode {
|
||||||
|
RoutingMode::OpenAI { .. } => ConnectionMode::Http,
|
||||||
|
_ => Self::determine_connection_mode(&all_urls),
|
||||||
|
};
|
||||||
|
|
||||||
// Build RouterConfig
|
// Build RouterConfig
|
||||||
Ok(RouterConfig {
|
Ok(RouterConfig {
|
||||||
@@ -543,16 +585,28 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
// Print startup info
|
// Print startup info
|
||||||
println!("SGLang Router starting...");
|
println!("SGLang Router starting...");
|
||||||
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
||||||
println!(
|
let mode_str = if cli_args.enable_igw {
|
||||||
"Mode: {}",
|
"IGW (Inference Gateway)".to_string()
|
||||||
if cli_args.enable_igw {
|
} else if matches!(cli_args.backend, Backend::Openai) {
|
||||||
"IGW (Inference Gateway)"
|
"OpenAI Backend".to_string()
|
||||||
} else if cli_args.pd_disaggregation {
|
} else if cli_args.pd_disaggregation {
|
||||||
"PD Disaggregated"
|
"PD Disaggregated".to_string()
|
||||||
} else {
|
} else {
|
||||||
"Regular"
|
format!("Regular ({})", cli_args.backend)
|
||||||
|
};
|
||||||
|
println!("Mode: {}", mode_str);
|
||||||
|
|
||||||
|
// Warn for runtimes that are parsed but not yet implemented
|
||||||
|
match cli_args.backend {
|
||||||
|
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
|
||||||
|
println!(
|
||||||
|
"WARNING: runtime '{}' not implemented yet; falling back to regular routing. \
|
||||||
|
Provide --worker-urls or PD flags as usual.",
|
||||||
|
cli_args.backend
|
||||||
|
);
|
||||||
}
|
}
|
||||||
);
|
Backend::Sglang | Backend::Openai => {}
|
||||||
|
}
|
||||||
|
|
||||||
if !cli_args.enable_igw {
|
if !cli_args.enable_igw {
|
||||||
println!("Policy: {}", cli_args.policy);
|
println!("Policy: {}", cli_args.policy);
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
//! Factory for creating router instances
|
//! Factory for creating router instances
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
http::{pd_router::PDRouter, router::Router},
|
http::{openai_router::OpenAIRouter, pd_router::PDRouter, router::Router},
|
||||||
RouterTrait,
|
RouterTrait,
|
||||||
};
|
};
|
||||||
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
|
use crate::config::{ConnectionMode, PolicyConfig, RoutingMode};
|
||||||
@@ -44,6 +44,9 @@ impl RouterFactory {
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
RoutingMode::OpenAI { .. } => {
|
||||||
|
Err("OpenAI mode requires HTTP connection_mode".to_string())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ConnectionMode::Http => {
|
ConnectionMode::Http => {
|
||||||
@@ -69,6 +72,9 @@ impl RouterFactory {
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
RoutingMode::OpenAI { worker_urls, .. } => {
|
||||||
|
Self::create_openai_router(worker_urls.clone(), ctx).await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -164,6 +170,23 @@ impl RouterFactory {
|
|||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create an OpenAI router
|
||||||
|
async fn create_openai_router(
|
||||||
|
worker_urls: Vec<String>,
|
||||||
|
ctx: &Arc<AppContext>,
|
||||||
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
|
// Use the first worker URL as the OpenAI-compatible base
|
||||||
|
let base_url = worker_urls
|
||||||
|
.first()
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
|
||||||
|
|
||||||
|
let router =
|
||||||
|
OpenAIRouter::new(base_url, Some(ctx.router_config.circuit_breaker.clone())).await?;
|
||||||
|
|
||||||
|
Ok(Box::new(router))
|
||||||
|
}
|
||||||
|
|
||||||
/// Create an IGW router (placeholder for future implementation)
|
/// Create an IGW router (placeholder for future implementation)
|
||||||
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
// For now, return an error indicating IGW is not yet implemented
|
// For now, return an error indicating IGW is not yet implemented
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
//! HTTP router implementations
|
//! HTTP router implementations
|
||||||
|
|
||||||
|
pub mod openai_router;
|
||||||
pub mod pd_router;
|
pub mod pd_router;
|
||||||
pub mod pd_types;
|
pub mod pd_types;
|
||||||
pub mod router;
|
pub mod router;
|
||||||
|
|||||||
379
sgl-router/src/routers/http/openai_router.rs
Normal file
379
sgl-router/src/routers/http/openai_router.rs
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
//! OpenAI router implementation (reqwest-based)
|
||||||
|
|
||||||
|
use crate::config::CircuitBreakerConfig;
|
||||||
|
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
|
||||||
|
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
extract::Request,
|
||||||
|
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use std::{
|
||||||
|
any::Any,
|
||||||
|
sync::atomic::{AtomicBool, Ordering},
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Router for OpenAI backend
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct OpenAIRouter {
|
||||||
|
/// HTTP client for upstream OpenAI-compatible API
|
||||||
|
client: reqwest::Client,
|
||||||
|
/// Base URL for identification (no trailing slash)
|
||||||
|
base_url: String,
|
||||||
|
/// Circuit breaker
|
||||||
|
circuit_breaker: CircuitBreaker,
|
||||||
|
/// Health status
|
||||||
|
healthy: AtomicBool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OpenAIRouter {
|
||||||
|
/// Create a new OpenAI router
|
||||||
|
pub async fn new(
|
||||||
|
base_url: String,
|
||||||
|
circuit_breaker_config: Option<CircuitBreakerConfig>,
|
||||||
|
) -> Result<Self, String> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
|
let base_url = base_url.trim_end_matches('/').to_string();
|
||||||
|
|
||||||
|
// Convert circuit breaker config
|
||||||
|
let core_cb_config = circuit_breaker_config
|
||||||
|
.map(|cb| CoreCircuitBreakerConfig {
|
||||||
|
failure_threshold: cb.failure_threshold,
|
||||||
|
success_threshold: cb.success_threshold,
|
||||||
|
timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
|
||||||
|
window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
client,
|
||||||
|
base_url,
|
||||||
|
circuit_breaker,
|
||||||
|
healthy: AtomicBool::new(true),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl super::super::WorkerManagement for OpenAIRouter {
|
||||||
|
async fn add_worker(&self, _worker_url: &str) -> Result<String, String> {
|
||||||
|
Err("Cannot add workers to OpenAI router".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_worker(&self, _worker_url: &str) {
|
||||||
|
// No-op for OpenAI router
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_worker_urls(&self) -> Vec<String> {
|
||||||
|
vec![self.base_url.clone()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl super::super::RouterTrait for OpenAIRouter {
|
||||||
|
fn as_any(&self) -> &dyn Any {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health(&self, _req: Request<Body>) -> Response {
|
||||||
|
// Simple upstream probe: GET {base}/v1/models without auth
|
||||||
|
let url = format!("{}/v1/models", self.base_url);
|
||||||
|
match self
|
||||||
|
.client
|
||||||
|
.get(&url)
|
||||||
|
.timeout(std::time::Duration::from_secs(2))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => {
|
||||||
|
let code = resp.status();
|
||||||
|
// Treat success and auth-required as healthy (endpoint reachable)
|
||||||
|
if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
|
||||||
|
(StatusCode::OK, "OK").into_response()
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
format!("Upstream status: {}", code),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
format!("Upstream error: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
|
// For OpenAI, health_generate is the same as health
|
||||||
|
self.health(_req).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||||
|
let info = serde_json::json!({
|
||||||
|
"router_type": "openai",
|
||||||
|
"workers": 1,
|
||||||
|
"base_url": &self.base_url
|
||||||
|
});
|
||||||
|
(StatusCode::OK, info.to_string()).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||||
|
// Proxy to upstream /v1/models; forward Authorization header if provided
|
||||||
|
let headers = req.headers();
|
||||||
|
|
||||||
|
let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
|
||||||
|
|
||||||
|
if let Some(auth) = headers
|
||||||
|
.get("authorization")
|
||||||
|
.or_else(|| headers.get("Authorization"))
|
||||||
|
{
|
||||||
|
upstream = upstream.header("Authorization", auth);
|
||||||
|
}
|
||||||
|
|
||||||
|
match upstream.send().await {
|
||||||
|
Ok(res) => {
|
||||||
|
let status = StatusCode::from_u16(res.status().as_u16())
|
||||||
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
let content_type = res.headers().get(CONTENT_TYPE).cloned();
|
||||||
|
match res.bytes().await {
|
||||||
|
Ok(body) => {
|
||||||
|
let mut response = Response::new(axum::body::Body::from(body));
|
||||||
|
*response.status_mut() = status;
|
||||||
|
if let Some(ct) = content_type {
|
||||||
|
response.headers_mut().insert(CONTENT_TYPE, ct);
|
||||||
|
}
|
||||||
|
response
|
||||||
|
}
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to read upstream response: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::BAD_GATEWAY,
|
||||||
|
format!("Failed to contact upstream: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||||
|
// Not directly supported without model param; return 501
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"get_model_info not implemented for OpenAI router",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_generate(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_body: &GenerateRequest,
|
||||||
|
) -> Response {
|
||||||
|
// Generate endpoint is SGLang-specific, not supported for OpenAI backend
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Generate endpoint not supported for OpenAI backend",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_chat(
|
||||||
|
&self,
|
||||||
|
headers: Option<&HeaderMap>,
|
||||||
|
body: &ChatCompletionRequest,
|
||||||
|
) -> Response {
|
||||||
|
if !self.circuit_breaker.can_execute() {
|
||||||
|
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize request body, removing SGLang-only fields
|
||||||
|
let mut payload = match serde_json::to_value(body) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
format!("Failed to serialize request: {}", e),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Some(obj) = payload.as_object_mut() {
|
||||||
|
for key in [
|
||||||
|
"top_k",
|
||||||
|
"min_p",
|
||||||
|
"min_tokens",
|
||||||
|
"regex",
|
||||||
|
"ebnf",
|
||||||
|
"stop_token_ids",
|
||||||
|
"no_stop_trim",
|
||||||
|
"ignore_eos",
|
||||||
|
"continue_final_message",
|
||||||
|
"skip_special_tokens",
|
||||||
|
"lora_path",
|
||||||
|
"session_params",
|
||||||
|
"separate_reasoning",
|
||||||
|
"stream_reasoning",
|
||||||
|
"chat_template_kwargs",
|
||||||
|
"return_hidden_states",
|
||||||
|
"repetition_penalty",
|
||||||
|
] {
|
||||||
|
obj.remove(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let url = format!("{}/v1/chat/completions", self.base_url);
|
||||||
|
let mut req = self.client.post(&url).json(&payload);
|
||||||
|
|
||||||
|
// Forward Authorization header if provided
|
||||||
|
if let Some(h) = headers {
|
||||||
|
if let Some(auth) = h.get("authorization").or_else(|| h.get("Authorization")) {
|
||||||
|
req = req.header("Authorization", auth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept SSE when stream=true
|
||||||
|
if body.stream {
|
||||||
|
req = req.header("Accept", "text/event-stream");
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = match req.send().await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
self.circuit_breaker.record_failure();
|
||||||
|
return (
|
||||||
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
format!("Failed to contact upstream: {}", e),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let status = StatusCode::from_u16(resp.status().as_u16())
|
||||||
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
|
if !body.stream {
|
||||||
|
// Capture Content-Type before consuming response body
|
||||||
|
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
|
||||||
|
match resp.bytes().await {
|
||||||
|
Ok(body) => {
|
||||||
|
self.circuit_breaker.record_success();
|
||||||
|
let mut response = Response::new(axum::body::Body::from(body));
|
||||||
|
*response.status_mut() = status;
|
||||||
|
if let Some(ct) = content_type {
|
||||||
|
response.headers_mut().insert(CONTENT_TYPE, ct);
|
||||||
|
}
|
||||||
|
response
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
self.circuit_breaker.record_failure();
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to read response: {}", e),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Stream SSE bytes to client
|
||||||
|
let stream = resp.bytes_stream();
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut s = stream;
|
||||||
|
while let Some(chunk) = s.next().await {
|
||||||
|
match chunk {
|
||||||
|
Ok(bytes) => {
|
||||||
|
if tx.send(Ok(bytes)).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(format!("Stream error: {}", e)));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let mut response = Response::new(Body::from_stream(
|
||||||
|
tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
|
||||||
|
));
|
||||||
|
*response.status_mut() = status;
|
||||||
|
response
|
||||||
|
.headers_mut()
|
||||||
|
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||||
|
response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_completion(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_body: &CompletionRequest,
|
||||||
|
) -> Response {
|
||||||
|
// Completion endpoint not implemented for OpenAI backend
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Completion endpoint not implemented for OpenAI backend",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn flush_cache(&self) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"flush_cache not supported for OpenAI router",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_worker_loads(&self) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"get_worker_loads not supported for OpenAI router",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn router_type(&self) -> &'static str {
|
||||||
|
"openai"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn readiness(&self) -> Response {
|
||||||
|
if self.healthy.load(Ordering::Acquire) && self.circuit_breaker.can_execute() {
|
||||||
|
(StatusCode::OK, "Ready").into_response()
|
||||||
|
} else {
|
||||||
|
(StatusCode::SERVICE_UNAVAILABLE, "Not ready").into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Embeddings endpoint not implemented for OpenAI backend",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Rerank endpoint not implemented for OpenAI backend",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -17,6 +17,8 @@ pub mod header_utils;
|
|||||||
pub mod http;
|
pub mod http;
|
||||||
|
|
||||||
pub use factory::RouterFactory;
|
pub use factory::RouterFactory;
|
||||||
|
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
|
||||||
|
pub use http::{openai_router, pd_router, pd_types, router};
|
||||||
|
|
||||||
/// Worker management trait for administrative operations
|
/// Worker management trait for administrative operations
|
||||||
///
|
///
|
||||||
|
|||||||
@@ -63,10 +63,7 @@ impl ServerHandler for MockSearchServer {
|
|||||||
ServerInfo {
|
ServerInfo {
|
||||||
protocol_version: ProtocolVersion::V_2024_11_05,
|
protocol_version: ProtocolVersion::V_2024_11_05,
|
||||||
capabilities: ServerCapabilities::builder().enable_tools().build(),
|
capabilities: ServerCapabilities::builder().enable_tools().build(),
|
||||||
server_info: Implementation {
|
server_info: Implementation::from_build_env(),
|
||||||
name: "Mock MCP Server".to_string(),
|
|
||||||
version: "1.0.0".to_string(),
|
|
||||||
},
|
|
||||||
instructions: Some("Mock server for testing".to_string()),
|
instructions: Some("Mock server for testing".to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
238
sgl-router/tests/common/mock_openai_server.rs
Normal file
238
sgl-router/tests/common/mock_openai_server.rs
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
//! Mock servers for testing
|
||||||
|
|
||||||
|
#![allow(dead_code)]
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
extract::{Request, State},
|
||||||
|
http::{HeaderValue, StatusCode},
|
||||||
|
response::sse::{Event, KeepAlive},
|
||||||
|
response::{IntoResponse, Response, Sse},
|
||||||
|
routing::post,
|
||||||
|
Json, Router,
|
||||||
|
};
|
||||||
|
use futures_util::stream::{self, StreamExt};
|
||||||
|
use serde_json::json;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
/// Mock OpenAI API server for testing
|
||||||
|
pub struct MockOpenAIServer {
|
||||||
|
addr: SocketAddr,
|
||||||
|
_handle: tokio::task::JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct MockServerState {
|
||||||
|
require_auth: bool,
|
||||||
|
expected_auth: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockOpenAIServer {
|
||||||
|
/// Create and start a new mock OpenAI server
|
||||||
|
pub async fn new() -> Self {
|
||||||
|
Self::new_with_auth(None).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create and start a new mock OpenAI server with optional auth requirement
|
||||||
|
pub async fn new_with_auth(expected_auth: Option<String>) -> Self {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let state = Arc::new(MockServerState {
|
||||||
|
require_auth: expected_auth.is_some(),
|
||||||
|
expected_auth,
|
||||||
|
});
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/v1/chat/completions", post(mock_chat_completions))
|
||||||
|
.route("/v1/completions", post(mock_completions))
|
||||||
|
.route("/v1/models", post(mock_models).get(mock_models))
|
||||||
|
.with_state(state);
|
||||||
|
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app).await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Give the server a moment to start
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
addr,
|
||||||
|
_handle: handle,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the base URL for this mock server
|
||||||
|
pub fn base_url(&self) -> String {
|
||||||
|
format!("http://{}", self.addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mock chat completions endpoint
|
||||||
|
async fn mock_chat_completions(req: Request<Body>) -> Response {
|
||||||
|
let (_, body) = req.into_parts();
|
||||||
|
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||||
|
Ok(bytes) => bytes,
|
||||||
|
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
|
||||||
|
Ok(req) => req,
|
||||||
|
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract model from request or use default (owned String to satisfy 'static in stream)
|
||||||
|
let model: String = request
|
||||||
|
.get("model")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("gpt-3.5-turbo")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// If stream requested, return SSE
|
||||||
|
let is_stream = request
|
||||||
|
.get("stream")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
if is_stream {
|
||||||
|
let created = 1677652288u64;
|
||||||
|
// Single chunk then [DONE]
|
||||||
|
let model_chunk = model.clone();
|
||||||
|
let event_stream = stream::once(async move {
|
||||||
|
let chunk = json!({
|
||||||
|
"id": "chatcmpl-123456789",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": created,
|
||||||
|
"model": model_chunk,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"content": "Hello!"
|
||||||
|
},
|
||||||
|
"finish_reason": null
|
||||||
|
}]
|
||||||
|
});
|
||||||
|
Ok::<_, std::convert::Infallible>(Event::default().data(chunk.to_string()))
|
||||||
|
})
|
||||||
|
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
|
||||||
|
|
||||||
|
Sse::new(event_stream)
|
||||||
|
.keep_alive(KeepAlive::default())
|
||||||
|
.into_response()
|
||||||
|
} else {
|
||||||
|
// Create a mock non-streaming response
|
||||||
|
let response = json!({
|
||||||
|
"id": "chatcmpl-123456789",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1677652288,
|
||||||
|
"model": model,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello! I'm a mock OpenAI assistant. How can I help you today?"
|
||||||
|
},
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 9,
|
||||||
|
"completion_tokens": 12,
|
||||||
|
"total_tokens": 21
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Json(response).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mock completions endpoint (legacy)
|
||||||
|
async fn mock_completions(req: Request<Body>) -> Response {
|
||||||
|
let (_, body) = req.into_parts();
|
||||||
|
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
|
||||||
|
Ok(bytes) => bytes,
|
||||||
|
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
|
||||||
|
Ok(req) => req,
|
||||||
|
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let model = request["model"].as_str().unwrap_or("text-davinci-003");
|
||||||
|
|
||||||
|
let response = json!({
|
||||||
|
"id": "cmpl-123456789",
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": 1677652288,
|
||||||
|
"model": model,
|
||||||
|
"choices": [{
|
||||||
|
"text": " This is a mock completion response.",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"finish_reason": "stop"
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 5,
|
||||||
|
"completion_tokens": 7,
|
||||||
|
"total_tokens": 12
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Json(response).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Mock models endpoint
|
||||||
|
async fn mock_models(State(state): State<Arc<MockServerState>>, req: Request<Body>) -> Response {
|
||||||
|
// Optionally enforce Authorization header
|
||||||
|
if state.require_auth {
|
||||||
|
let auth = req
|
||||||
|
.headers()
|
||||||
|
.get("authorization")
|
||||||
|
.or_else(|| req.headers().get("Authorization"))
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
let auth_ok = match (&state.expected_auth, auth) {
|
||||||
|
(Some(expected), Some(got)) => &got == expected,
|
||||||
|
(None, Some(_)) => true,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
if !auth_ok {
|
||||||
|
let mut response = Response::new(Body::from(
|
||||||
|
json!({
|
||||||
|
"error": {
|
||||||
|
"message": "Unauthorized",
|
||||||
|
"type": "invalid_request_error"
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string(),
|
||||||
|
));
|
||||||
|
*response.status_mut() = StatusCode::UNAUTHORIZED;
|
||||||
|
response
|
||||||
|
.headers_mut()
|
||||||
|
.insert("WWW-Authenticate", HeaderValue::from_static("Bearer"));
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = json!({
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": "gpt-4",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1677610602,
|
||||||
|
"owned_by": "openai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-3.5-turbo",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1677610602,
|
||||||
|
"owned_by": "openai"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
|
Json(response).into_response()
|
||||||
|
}
|
||||||
0
sgl-router/tests/common/mock_worker.rs
Normal file → Executable file
0
sgl-router/tests/common/mock_worker.rs
Normal file → Executable file
@@ -2,6 +2,7 @@
|
|||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
pub mod mock_mcp_server;
|
pub mod mock_mcp_server;
|
||||||
|
pub mod mock_openai_server;
|
||||||
pub mod mock_worker;
|
pub mod mock_worker;
|
||||||
pub mod test_app;
|
pub mod test_app;
|
||||||
|
|
||||||
|
|||||||
419
sgl-router/tests/test_openai_routing.rs
Normal file
419
sgl-router/tests/test_openai_routing.rs
Normal file
@@ -0,0 +1,419 @@
|
|||||||
|
//! Comprehensive integration tests for OpenAI backend functionality
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
extract::Request,
|
||||||
|
http::{Method, StatusCode},
|
||||||
|
routing::post,
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use serde_json::json;
|
||||||
|
use sglang_router_rs::{
|
||||||
|
config::{RouterConfig, RoutingMode},
|
||||||
|
protocols::spec::{
|
||||||
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, UserMessageContent,
|
||||||
|
},
|
||||||
|
routers::{openai_router::OpenAIRouter, RouterTrait},
|
||||||
|
};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tower::ServiceExt;
|
||||||
|
|
||||||
|
mod common;
|
||||||
|
use common::mock_openai_server::MockOpenAIServer;
|
||||||
|
|
||||||
|
/// Helper function to create a minimal chat completion request for testing
|
||||||
|
fn create_minimal_chat_request() -> ChatCompletionRequest {
|
||||||
|
let val = json!({
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
],
|
||||||
|
"max_tokens": 100
|
||||||
|
});
|
||||||
|
serde_json::from_value(val).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to create a minimal completion request for testing
|
||||||
|
fn create_minimal_completion_request() -> CompletionRequest {
|
||||||
|
CompletionRequest {
|
||||||
|
model: "gpt-3.5-turbo".to_string(),
|
||||||
|
prompt: sglang_router_rs::protocols::spec::StringOrArray::String("Hello".to_string()),
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: Some(100),
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
n: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
logprobs: None,
|
||||||
|
echo: false,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
best_of: None,
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
json_schema: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
skip_special_tokens: true,
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
return_hidden_states: false,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Basic Unit Tests =============
|
||||||
|
|
||||||
|
/// Test basic OpenAI router creation and configuration
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_openai_router_creation() {
|
||||||
|
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None).await;
|
||||||
|
|
||||||
|
assert!(router.is_ok(), "Router creation should succeed");
|
||||||
|
|
||||||
|
let router = router.unwrap();
|
||||||
|
assert_eq!(router.router_type(), "openai");
|
||||||
|
assert!(!router.is_pd_mode());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test health endpoints
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_openai_router_health() {
|
||||||
|
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let req = Request::builder()
|
||||||
|
.method(Method::GET)
|
||||||
|
.uri("/health")
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = router.health(req).await;
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test server info endpoint
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_openai_router_server_info() {
|
||||||
|
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let req = Request::builder()
|
||||||
|
.method(Method::GET)
|
||||||
|
.uri("/info")
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = router.get_server_info(req).await;
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let (_, body) = response.into_parts();
|
||||||
|
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||||
|
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||||
|
|
||||||
|
assert!(body_str.contains("openai"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test models endpoint
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_openai_router_models() {
|
||||||
|
// Use mock server for deterministic models response
|
||||||
|
let mock_server = MockOpenAIServer::new().await;
|
||||||
|
let router = OpenAIRouter::new(mock_server.base_url(), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let req = Request::builder()
|
||||||
|
.method(Method::GET)
|
||||||
|
.uri("/models")
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = router.get_models(req).await;
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let (_, body) = response.into_parts();
|
||||||
|
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||||
|
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||||
|
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(models["object"], "list");
|
||||||
|
assert!(models["data"].is_array());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test router factory with OpenAI routing mode
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_router_factory_openai_mode() {
|
||||||
|
let routing_mode = RoutingMode::OpenAI {
|
||||||
|
worker_urls: vec!["https://api.openai.com".to_string()],
|
||||||
|
};
|
||||||
|
|
||||||
|
let router_config =
|
||||||
|
RouterConfig::new(routing_mode, sglang_router_rs::config::PolicyConfig::Random);
|
||||||
|
|
||||||
|
let app_context = common::create_test_context(router_config);
|
||||||
|
|
||||||
|
let router = sglang_router_rs::routers::RouterFactory::create_router(&app_context).await;
|
||||||
|
assert!(
|
||||||
|
router.is_ok(),
|
||||||
|
"Router factory should create OpenAI router successfully"
|
||||||
|
);
|
||||||
|
|
||||||
|
let router = router.unwrap();
|
||||||
|
assert_eq!(router.router_type(), "openai");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test that unsupported endpoints return proper error codes
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unsupported_endpoints() {
|
||||||
|
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Test generate endpoint (SGLang-specific, should not be supported)
|
||||||
|
let generate_request = GenerateRequest {
|
||||||
|
prompt: None,
|
||||||
|
text: Some("Hello world".to_string()),
|
||||||
|
input_ids: None,
|
||||||
|
parameters: None,
|
||||||
|
sampling_params: None,
|
||||||
|
stream: false,
|
||||||
|
return_logprob: false,
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
return_hidden_states: false,
|
||||||
|
rid: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = router.route_generate(None, &generate_request).await;
|
||||||
|
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||||
|
|
||||||
|
// Test completion endpoint (should also not be supported)
|
||||||
|
let completion_request = create_minimal_completion_request();
|
||||||
|
let response = router.route_completion(None, &completion_request).await;
|
||||||
|
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Mock Server E2E Tests =============
|
||||||
|
|
||||||
|
/// Test chat completion with mock OpenAI server
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_openai_router_chat_completion_with_mock() {
|
||||||
|
// Start a mock OpenAI server
|
||||||
|
let mock_server = MockOpenAIServer::new().await;
|
||||||
|
let base_url = mock_server.base_url();
|
||||||
|
|
||||||
|
// Create router pointing to mock server
|
||||||
|
let router = OpenAIRouter::new(base_url, None).await.unwrap();
|
||||||
|
|
||||||
|
// Create a minimal chat completion request
|
||||||
|
let mut chat_request = create_minimal_chat_request();
|
||||||
|
chat_request.messages = vec![ChatMessage::User {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: UserMessageContent::Text("Hello, how are you?".to_string()),
|
||||||
|
name: None,
|
||||||
|
}];
|
||||||
|
chat_request.temperature = Some(0.7);
|
||||||
|
|
||||||
|
// Route the request
|
||||||
|
let response = router.route_chat(None, &chat_request).await;
|
||||||
|
|
||||||
|
// Should get a successful response from mock server
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let (_, body) = response.into_parts();
|
||||||
|
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||||
|
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||||
|
let chat_response: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
||||||
|
|
||||||
|
// Verify it's a valid chat completion response
|
||||||
|
assert_eq!(chat_response["object"], "chat.completion");
|
||||||
|
assert_eq!(chat_response["model"], "gpt-3.5-turbo");
|
||||||
|
assert!(!chat_response["choices"].as_array().unwrap().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test full E2E flow with Axum server
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_openai_e2e_with_server() {
|
||||||
|
// Start mock OpenAI server
|
||||||
|
let mock_server = MockOpenAIServer::new().await;
|
||||||
|
let base_url = mock_server.base_url();
|
||||||
|
|
||||||
|
// Create router
|
||||||
|
let router = OpenAIRouter::new(base_url, None).await.unwrap();
|
||||||
|
|
||||||
|
// Create Axum app with chat completions endpoint
|
||||||
|
let app = Router::new().route(
|
||||||
|
"/v1/chat/completions",
|
||||||
|
post({
|
||||||
|
let router = Arc::new(router);
|
||||||
|
move |req: Request<Body>| {
|
||||||
|
let router = router.clone();
|
||||||
|
async move {
|
||||||
|
let (parts, body) = req.into_parts();
|
||||||
|
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||||
|
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||||
|
|
||||||
|
let chat_request: ChatCompletionRequest =
|
||||||
|
serde_json::from_str(&body_str).unwrap();
|
||||||
|
|
||||||
|
router.route_chat(Some(&parts.headers), &chat_request).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Make a request to the server
|
||||||
|
let request = Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.uri("/v1/chat/completions")
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.body(Body::from(
|
||||||
|
json!({
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, world!"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"max_tokens": 100
|
||||||
|
})
|
||||||
|
.to_string(),
|
||||||
|
))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = app.oneshot(request).await.unwrap();
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
|
||||||
|
// Verify the response structure
|
||||||
|
assert_eq!(response_json["object"], "chat.completion");
|
||||||
|
assert_eq!(response_json["model"], "gpt-3.5-turbo");
|
||||||
|
assert!(!response_json["choices"].as_array().unwrap().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test streaming chat completions pass-through with mock server
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_openai_router_chat_streaming_with_mock() {
|
||||||
|
let mock_server = MockOpenAIServer::new().await;
|
||||||
|
let base_url = mock_server.base_url();
|
||||||
|
let router = OpenAIRouter::new(base_url, None).await.unwrap();
|
||||||
|
|
||||||
|
// Build a streaming chat request
|
||||||
|
let val = json!({
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hello"}
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
|
"stream": true
|
||||||
|
});
|
||||||
|
let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();
|
||||||
|
|
||||||
|
let response = router.route_chat(None, &chat_request).await;
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
// Should be SSE
|
||||||
|
let headers = response.headers();
|
||||||
|
let ct = headers
|
||||||
|
.get("content-type")
|
||||||
|
.unwrap()
|
||||||
|
.to_str()
|
||||||
|
.unwrap()
|
||||||
|
.to_ascii_lowercase();
|
||||||
|
assert!(ct.contains("text/event-stream"));
|
||||||
|
|
||||||
|
// Read entire stream body and assert chunks + DONE
|
||||||
|
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let text = String::from_utf8(body.to_vec()).unwrap();
|
||||||
|
assert!(text.contains("chat.completion.chunk"));
|
||||||
|
assert!(text.contains("[DONE]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test circuit breaker functionality
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_openai_router_circuit_breaker() {
|
||||||
|
// Create router with circuit breaker config
|
||||||
|
let cb_config = sglang_router_rs::config::CircuitBreakerConfig {
|
||||||
|
failure_threshold: 2,
|
||||||
|
success_threshold: 1,
|
||||||
|
timeout_duration_secs: 1,
|
||||||
|
window_duration_secs: 10,
|
||||||
|
};
|
||||||
|
|
||||||
|
let router = OpenAIRouter::new(
|
||||||
|
"http://invalid-url-that-will-fail".to_string(),
|
||||||
|
Some(cb_config),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let chat_request = create_minimal_chat_request();
|
||||||
|
|
||||||
|
// First few requests should fail and record failures
|
||||||
|
for _ in 0..3 {
|
||||||
|
let response = router.route_chat(None, &chat_request).await;
|
||||||
|
// Should get either an error or circuit breaker response
|
||||||
|
assert!(
|
||||||
|
response.status() == StatusCode::INTERNAL_SERVER_ERROR
|
||||||
|
|| response.status() == StatusCode::SERVICE_UNAVAILABLE
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test that Authorization header is forwarded in /v1/models
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_openai_router_models_auth_forwarding() {
|
||||||
|
// Start a mock server that requires Authorization
|
||||||
|
let expected_auth = "Bearer test-token".to_string();
|
||||||
|
let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
|
||||||
|
let router = OpenAIRouter::new(mock_server.base_url(), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// 1) Without auth header -> expect 401
|
||||||
|
let req = Request::builder()
|
||||||
|
.method(Method::GET)
|
||||||
|
.uri("/models")
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = router.get_models(req).await;
|
||||||
|
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
|
||||||
|
// 2) With auth header -> expect 200
|
||||||
|
let req = Request::builder()
|
||||||
|
.method(Method::GET)
|
||||||
|
.uri("/models")
|
||||||
|
.header("Authorization", expected_auth)
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let response = router.get_models(req).await;
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let (_, body) = response.into_parts();
|
||||||
|
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
|
||||||
|
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
|
||||||
|
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
|
||||||
|
assert_eq!(models["object"], "list");
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user