[router] migrate router from actix to axum (#8479)
This commit is contained in:
@@ -10,41 +10,41 @@ name = "sglang_router_rs"
|
|||||||
crate-type = ["cdylib", "rlib"]
|
crate-type = ["cdylib", "rlib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
actix-web = "4.0"
|
axum = { version = "0.8.4", features = ["macros", "ws", "tracing"] }
|
||||||
|
tower = { version = "0.5", features = ["full"] }
|
||||||
|
tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] }
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
clap = { version = "4.4", features = ["derive"] }
|
serde_json = "1.0"
|
||||||
bytes = "1.8.0"
|
bytes = "1.8.0"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
|
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
serde_json = "1.0"
|
futures = "0.3"
|
||||||
pyo3 = { version = "0.22.5", features = ["extension-module"] }
|
pyo3 = { version = "0.22.5", features = ["extension-module"] }
|
||||||
dashmap = "6.1.0"
|
dashmap = "6.1.0"
|
||||||
http = "1.1.0"
|
http = "1.1.0"
|
||||||
tokio = { version = "1.42.0", features = ["macros", "rt-multi-thread"] }
|
tokio = { version = "1.42.0", features = ["full"] }
|
||||||
# Added for enhanced logging system
|
async-trait = "0.1"
|
||||||
|
once_cell = "1.21"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "chrono"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "chrono"] }
|
||||||
tracing-log = "0.2"
|
tracing-log = "0.2"
|
||||||
tracing-appender = "0.2.3"
|
tracing-appender = "0.2.3"
|
||||||
|
chrono = "0.4"
|
||||||
kube = { version = "0.88.1", features = ["runtime", "derive"] }
|
kube = { version = "0.88.1", features = ["runtime", "derive"] }
|
||||||
k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
|
k8s-openapi = { version = "0.21.0", features = ["v1_29"] }
|
||||||
futures = "0.3"
|
|
||||||
async-trait = "0.1"
|
|
||||||
once_cell = "1.21"
|
|
||||||
# Added for metrics
|
|
||||||
metrics = "0.24.2"
|
metrics = "0.24.2"
|
||||||
metrics-exporter-prometheus = "0.17.0"
|
metrics-exporter-prometheus = "0.17.0"
|
||||||
# Added for request tracing
|
|
||||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||||
thiserror = "2.0.12"
|
thiserror = "2.0.12"
|
||||||
url = "2.5.4"
|
url = "2.5.4"
|
||||||
|
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = { version = "0.5", features = ["html_reports"] }
|
criterion = { version = "0.5", features = ["html_reports"] }
|
||||||
tokio-stream = "0.1"
|
tower = { version = "0.5", features = ["util"] }
|
||||||
actix-http = "3.0"
|
http-body-util = "0.1"
|
||||||
futures = "0.3"
|
portpicker = "0.1"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "request_processing"
|
name = "request_processing"
|
||||||
|
|||||||
@@ -68,6 +68,12 @@ class RouterArgs:
|
|||||||
prometheus_host: Optional[str] = None
|
prometheus_host: Optional[str] = None
|
||||||
# Request ID headers configuration
|
# Request ID headers configuration
|
||||||
request_id_headers: Optional[List[str]] = None
|
request_id_headers: Optional[List[str]] = None
|
||||||
|
# Request timeout in seconds
|
||||||
|
request_timeout_secs: int = 600
|
||||||
|
# Max concurrent requests for rate limiting
|
||||||
|
max_concurrent_requests: int = 64
|
||||||
|
# CORS allowed origins
|
||||||
|
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(
|
def add_cli_args(
|
||||||
@@ -276,6 +282,25 @@ class RouterArgs:
|
|||||||
nargs="*",
|
nargs="*",
|
||||||
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}request-timeout-secs",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.request_timeout_secs,
|
||||||
|
help="Request timeout in seconds",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}max-concurrent-requests",
|
||||||
|
type=int,
|
||||||
|
default=RouterArgs.max_concurrent_requests,
|
||||||
|
help="Maximum number of concurrent requests allowed (for rate limiting)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}cors-allowed-origins",
|
||||||
|
type=str,
|
||||||
|
nargs="*",
|
||||||
|
default=[],
|
||||||
|
help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(
|
def from_cli_args(
|
||||||
@@ -337,6 +362,15 @@ class RouterArgs:
|
|||||||
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
||||||
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
||||||
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
|
request_id_headers=getattr(args, f"{prefix}request_id_headers", None),
|
||||||
|
request_timeout_secs=getattr(
|
||||||
|
args, f"{prefix}request_timeout_secs", RouterArgs.request_timeout_secs
|
||||||
|
),
|
||||||
|
max_concurrent_requests=getattr(
|
||||||
|
args,
|
||||||
|
f"{prefix}max_concurrent_requests",
|
||||||
|
RouterArgs.max_concurrent_requests,
|
||||||
|
),
|
||||||
|
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -490,6 +524,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
decode_selector=router_args.decode_selector,
|
decode_selector=router_args.decode_selector,
|
||||||
prometheus_port=router_args.prometheus_port,
|
prometheus_port=router_args.prometheus_port,
|
||||||
prometheus_host=router_args.prometheus_host,
|
prometheus_host=router_args.prometheus_host,
|
||||||
|
request_timeout_secs=router_args.request_timeout_secs,
|
||||||
pd_disaggregation=router_args.pd_disaggregation,
|
pd_disaggregation=router_args.pd_disaggregation,
|
||||||
prefill_urls=(
|
prefill_urls=(
|
||||||
router_args.prefill_urls if router_args.pd_disaggregation else None
|
router_args.prefill_urls if router_args.pd_disaggregation else None
|
||||||
@@ -508,6 +543,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
request_id_headers=router_args.request_id_headers,
|
request_id_headers=router_args.request_id_headers,
|
||||||
|
max_concurrent_requests=router_args.max_concurrent_requests,
|
||||||
|
cors_allowed_origins=router_args.cors_allowed_origins,
|
||||||
)
|
)
|
||||||
|
|
||||||
router.start()
|
router.start()
|
||||||
|
|||||||
@@ -61,6 +61,11 @@ class Router:
|
|||||||
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
|
request_id_headers: List of HTTP headers to check for request IDs. If not specified,
|
||||||
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
|
uses common defaults: ['x-request-id', 'x-correlation-id', 'x-trace-id', 'request-id'].
|
||||||
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
|
Example: ['x-my-request-id', 'x-custom-trace-id']. Default: None
|
||||||
|
bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
|
||||||
|
Default: 'sglang.ai/bootstrap-port'
|
||||||
|
request_timeout_secs: Request timeout in seconds. Default: 600
|
||||||
|
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 64
|
||||||
|
cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -87,14 +92,18 @@ class Router:
|
|||||||
service_discovery_namespace: Optional[str] = None,
|
service_discovery_namespace: Optional[str] = None,
|
||||||
prefill_selector: Dict[str, str] = None,
|
prefill_selector: Dict[str, str] = None,
|
||||||
decode_selector: Dict[str, str] = None,
|
decode_selector: Dict[str, str] = None,
|
||||||
|
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port",
|
||||||
prometheus_port: Optional[int] = None,
|
prometheus_port: Optional[int] = None,
|
||||||
prometheus_host: Optional[str] = None,
|
prometheus_host: Optional[str] = None,
|
||||||
|
request_timeout_secs: int = 600,
|
||||||
|
request_id_headers: Optional[List[str]] = None,
|
||||||
pd_disaggregation: bool = False,
|
pd_disaggregation: bool = False,
|
||||||
prefill_urls: Optional[List[tuple]] = None,
|
prefill_urls: Optional[List[tuple]] = None,
|
||||||
decode_urls: Optional[List[str]] = None,
|
decode_urls: Optional[List[str]] = None,
|
||||||
prefill_policy: Optional[PolicyType] = None,
|
prefill_policy: Optional[PolicyType] = None,
|
||||||
decode_policy: Optional[PolicyType] = None,
|
decode_policy: Optional[PolicyType] = None,
|
||||||
request_id_headers: Optional[List[str]] = None,
|
max_concurrent_requests: int = 64,
|
||||||
|
cors_allowed_origins: List[str] = None,
|
||||||
):
|
):
|
||||||
if selector is None:
|
if selector is None:
|
||||||
selector = {}
|
selector = {}
|
||||||
@@ -102,6 +111,8 @@ class Router:
|
|||||||
prefill_selector = {}
|
prefill_selector = {}
|
||||||
if decode_selector is None:
|
if decode_selector is None:
|
||||||
decode_selector = {}
|
decode_selector = {}
|
||||||
|
if cors_allowed_origins is None:
|
||||||
|
cors_allowed_origins = []
|
||||||
|
|
||||||
self._router = _Router(
|
self._router = _Router(
|
||||||
worker_urls=worker_urls,
|
worker_urls=worker_urls,
|
||||||
@@ -126,14 +137,18 @@ class Router:
|
|||||||
service_discovery_namespace=service_discovery_namespace,
|
service_discovery_namespace=service_discovery_namespace,
|
||||||
prefill_selector=prefill_selector,
|
prefill_selector=prefill_selector,
|
||||||
decode_selector=decode_selector,
|
decode_selector=decode_selector,
|
||||||
|
bootstrap_port_annotation=bootstrap_port_annotation,
|
||||||
prometheus_port=prometheus_port,
|
prometheus_port=prometheus_port,
|
||||||
prometheus_host=prometheus_host,
|
prometheus_host=prometheus_host,
|
||||||
|
request_timeout_secs=request_timeout_secs,
|
||||||
|
request_id_headers=request_id_headers,
|
||||||
pd_disaggregation=pd_disaggregation,
|
pd_disaggregation=pd_disaggregation,
|
||||||
prefill_urls=prefill_urls,
|
prefill_urls=prefill_urls,
|
||||||
decode_urls=decode_urls,
|
decode_urls=decode_urls,
|
||||||
prefill_policy=prefill_policy,
|
prefill_policy=prefill_policy,
|
||||||
decode_policy=decode_policy,
|
decode_policy=decode_policy,
|
||||||
request_id_headers=request_id_headers,
|
max_concurrent_requests=max_concurrent_requests,
|
||||||
|
cors_allowed_origins=cors_allowed_origins,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
|
|||||||
@@ -46,11 +46,12 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
dp_aware=False,
|
dp_aware=False,
|
||||||
prometheus_port=None,
|
prometheus_port=None,
|
||||||
prometheus_host=None,
|
prometheus_host=None,
|
||||||
# PD-specific attributes
|
request_timeout_secs=60,
|
||||||
|
max_concurrent_requests=64,
|
||||||
|
cors_allowed_origins=[],
|
||||||
pd_disaggregation=False,
|
pd_disaggregation=False,
|
||||||
prefill=None,
|
prefill=None,
|
||||||
decode=None,
|
decode=None,
|
||||||
# Keep worker_urls for regular mode
|
|
||||||
worker_urls=[],
|
worker_urls=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,10 @@ pub struct RouterConfig {
|
|||||||
pub log_level: Option<String>,
|
pub log_level: Option<String>,
|
||||||
/// Custom request ID headers to check (defaults to common headers)
|
/// Custom request ID headers to check (defaults to common headers)
|
||||||
pub request_id_headers: Option<Vec<String>>,
|
pub request_id_headers: Option<Vec<String>>,
|
||||||
|
/// Maximum concurrent requests allowed (for rate limiting)
|
||||||
|
pub max_concurrent_requests: usize,
|
||||||
|
/// CORS allowed origins
|
||||||
|
pub cors_allowed_origins: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Routing mode configuration
|
/// Routing mode configuration
|
||||||
@@ -216,6 +220,8 @@ impl Default for RouterConfig {
|
|||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
request_id_headers: None,
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 64,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -324,6 +330,8 @@ mod tests {
|
|||||||
log_dir: Some("/var/log".to_string()),
|
log_dir: Some("/var/log".to_string()),
|
||||||
log_level: Some("debug".to_string()),
|
log_level: Some("debug".to_string()),
|
||||||
request_id_headers: None,
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 64,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
let json = serde_json::to_string(&config).unwrap();
|
let json = serde_json::to_string(&config).unwrap();
|
||||||
@@ -749,6 +757,8 @@ mod tests {
|
|||||||
log_dir: Some("/var/log/sglang".to_string()),
|
log_dir: Some("/var/log/sglang".to_string()),
|
||||||
log_level: Some("info".to_string()),
|
log_level: Some("info".to_string()),
|
||||||
request_id_headers: None,
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 64,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.mode.is_pd_mode());
|
assert!(config.mode.is_pd_mode());
|
||||||
@@ -798,6 +808,8 @@ mod tests {
|
|||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: Some("debug".to_string()),
|
log_level: Some("debug".to_string()),
|
||||||
request_id_headers: None,
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 64,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(!config.mode.is_pd_mode());
|
assert!(!config.mode.is_pd_mode());
|
||||||
@@ -843,6 +855,8 @@ mod tests {
|
|||||||
log_dir: Some("/opt/logs/sglang".to_string()),
|
log_dir: Some("/opt/logs/sglang".to_string()),
|
||||||
log_level: Some("trace".to_string()),
|
log_level: Some("trace".to_string()),
|
||||||
request_id_headers: None,
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 64,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
assert!(config.has_service_discovery());
|
assert!(config.has_service_discovery());
|
||||||
|
|||||||
@@ -60,6 +60,9 @@ struct Router {
|
|||||||
decode_urls: Option<Vec<String>>,
|
decode_urls: Option<Vec<String>>,
|
||||||
prefill_policy: Option<PolicyType>,
|
prefill_policy: Option<PolicyType>,
|
||||||
decode_policy: Option<PolicyType>,
|
decode_policy: Option<PolicyType>,
|
||||||
|
// Additional server config fields
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
cors_allowed_origins: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
@@ -145,6 +148,8 @@ impl Router {
|
|||||||
log_dir: self.log_dir.clone(),
|
log_dir: self.log_dir.clone(),
|
||||||
log_level: self.log_level.clone(),
|
log_level: self.log_level.clone(),
|
||||||
request_id_headers: self.request_id_headers.clone(),
|
request_id_headers: self.request_id_headers.clone(),
|
||||||
|
max_concurrent_requests: self.max_concurrent_requests,
|
||||||
|
cors_allowed_origins: self.cors_allowed_origins.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -184,7 +189,9 @@ impl Router {
|
|||||||
prefill_urls = None,
|
prefill_urls = None,
|
||||||
decode_urls = None,
|
decode_urls = None,
|
||||||
prefill_policy = None,
|
prefill_policy = None,
|
||||||
decode_policy = None
|
decode_policy = None,
|
||||||
|
max_concurrent_requests = 64,
|
||||||
|
cors_allowed_origins = vec![]
|
||||||
))]
|
))]
|
||||||
fn new(
|
fn new(
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
@@ -219,6 +226,8 @@ impl Router {
|
|||||||
decode_urls: Option<Vec<String>>,
|
decode_urls: Option<Vec<String>>,
|
||||||
prefill_policy: Option<PolicyType>,
|
prefill_policy: Option<PolicyType>,
|
||||||
decode_policy: Option<PolicyType>,
|
decode_policy: Option<PolicyType>,
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
cors_allowed_origins: Vec<String>,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
Ok(Router {
|
Ok(Router {
|
||||||
host,
|
host,
|
||||||
@@ -253,6 +262,8 @@ impl Router {
|
|||||||
decode_urls,
|
decode_urls,
|
||||||
prefill_policy,
|
prefill_policy,
|
||||||
decode_policy,
|
decode_policy,
|
||||||
|
max_concurrent_requests,
|
||||||
|
cors_allowed_origins,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
use actix_web::{
|
use axum::{extract::Request, http::HeaderValue, response::Response};
|
||||||
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
|
use std::sync::Arc;
|
||||||
Error, HttpMessage, HttpRequest,
|
use std::time::Instant;
|
||||||
};
|
use tower::{Layer, Service};
|
||||||
use futures_util::future::LocalBoxFuture;
|
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
|
||||||
use std::future::{ready, Ready};
|
use tracing::{field::Empty, info_span, Span};
|
||||||
|
|
||||||
/// Generate OpenAI-compatible request ID based on endpoint
|
/// Generate OpenAI-compatible request ID based on endpoint
|
||||||
fn generate_request_id(path: &str) -> String {
|
fn generate_request_id(path: &str) -> String {
|
||||||
@@ -31,67 +31,67 @@ fn generate_request_id(path: &str) -> String {
|
|||||||
format!("{}{}", prefix, random_part)
|
format!("{}{}", prefix, random_part)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract request ID from request extensions or generate a new one
|
/// Extension type for storing request ID
|
||||||
pub fn get_request_id(req: &HttpRequest) -> String {
|
#[derive(Clone, Debug)]
|
||||||
req.extensions()
|
pub struct RequestId(pub String);
|
||||||
.get::<String>()
|
|
||||||
.cloned()
|
/// Tower Layer for request ID middleware
|
||||||
.unwrap_or_else(|| generate_request_id(req.path()))
|
#[derive(Clone)]
|
||||||
|
pub struct RequestIdLayer {
|
||||||
|
headers: Arc<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Middleware for injecting request ID into request extensions
|
impl RequestIdLayer {
|
||||||
pub struct RequestIdMiddleware {
|
|
||||||
headers: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RequestIdMiddleware {
|
|
||||||
pub fn new(headers: Vec<String>) -> Self {
|
pub fn new(headers: Vec<String>) -> Self {
|
||||||
Self { headers }
|
Self {
|
||||||
|
headers: Arc::new(headers),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, B> Transform<S, ServiceRequest> for RequestIdMiddleware
|
impl<S> Layer<S> for RequestIdLayer {
|
||||||
where
|
type Service = RequestIdMiddleware<S>;
|
||||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
|
||||||
S::Future: 'static,
|
|
||||||
B: 'static,
|
|
||||||
{
|
|
||||||
type Response = ServiceResponse<B>;
|
|
||||||
type Error = Error;
|
|
||||||
type InitError = ();
|
|
||||||
type Transform = RequestIdMiddlewareService<S>;
|
|
||||||
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
|
||||||
|
|
||||||
fn new_transform(&self, service: S) -> Self::Future {
|
fn layer(&self, inner: S) -> Self::Service {
|
||||||
ready(Ok(RequestIdMiddlewareService {
|
RequestIdMiddleware {
|
||||||
service,
|
inner,
|
||||||
headers: self.headers.clone(),
|
headers: self.headers.clone(),
|
||||||
}))
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RequestIdMiddlewareService<S> {
|
/// Tower Service for request ID middleware
|
||||||
service: S,
|
#[derive(Clone)]
|
||||||
headers: Vec<String>,
|
pub struct RequestIdMiddleware<S> {
|
||||||
|
inner: S,
|
||||||
|
headers: Arc<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, B> Service<ServiceRequest> for RequestIdMiddlewareService<S>
|
impl<S> Service<Request> for RequestIdMiddleware<S>
|
||||||
where
|
where
|
||||||
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
|
S: Service<Request, Response = Response> + Send + 'static,
|
||||||
S::Future: 'static,
|
S::Future: Send + 'static,
|
||||||
B: 'static,
|
|
||||||
{
|
{
|
||||||
type Response = ServiceResponse<B>;
|
type Response = S::Response;
|
||||||
type Error = Error;
|
type Error = S::Error;
|
||||||
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
|
type Future = std::pin::Pin<
|
||||||
|
Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
|
||||||
|
>;
|
||||||
|
|
||||||
forward_ready!(service);
|
fn poll_ready(
|
||||||
|
&mut self,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
|
self.inner.poll_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, mut req: Request) -> Self::Future {
|
||||||
|
let headers = self.headers.clone();
|
||||||
|
|
||||||
fn call(&self, req: ServiceRequest) -> Self::Future {
|
|
||||||
// Extract request ID from headers or generate new one
|
// Extract request ID from headers or generate new one
|
||||||
let mut request_id = None;
|
let mut request_id = None;
|
||||||
|
|
||||||
for header_name in &self.headers {
|
for header_name in headers.iter() {
|
||||||
if let Some(header_value) = req.headers().get(header_name) {
|
if let Some(header_value) = req.headers().get(header_name) {
|
||||||
if let Ok(value) = header_value.to_str() {
|
if let Ok(value) = header_value.to_str() {
|
||||||
request_id = Some(value.to_string());
|
request_id = Some(value.to_string());
|
||||||
@@ -100,12 +100,216 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let request_id = request_id.unwrap_or_else(|| generate_request_id(req.path()));
|
let request_id = request_id.unwrap_or_else(|| generate_request_id(req.uri().path()));
|
||||||
|
|
||||||
// Insert request ID into request extensions
|
// Insert request ID into request extensions
|
||||||
req.extensions_mut().insert(request_id);
|
req.extensions_mut().insert(RequestId(request_id.clone()));
|
||||||
|
|
||||||
let fut = self.service.call(req);
|
// Create a span with the request ID for this request
|
||||||
Box::pin(async move { fut.await })
|
let span = tracing::info_span!(
|
||||||
|
"http_request",
|
||||||
|
method = %req.method(),
|
||||||
|
uri = %req.uri(),
|
||||||
|
version = ?req.version(),
|
||||||
|
request_id = %request_id
|
||||||
|
);
|
||||||
|
|
||||||
|
// Log within the span
|
||||||
|
let _enter = span.enter();
|
||||||
|
tracing::info!(
|
||||||
|
target: "sglang_router_rs::request",
|
||||||
|
"started processing request"
|
||||||
|
);
|
||||||
|
drop(_enter);
|
||||||
|
|
||||||
|
// Capture values we need in the async block
|
||||||
|
let method = req.method().clone();
|
||||||
|
let uri = req.uri().clone();
|
||||||
|
let version = req.version();
|
||||||
|
|
||||||
|
// Call the inner service
|
||||||
|
let future = self.inner.call(req);
|
||||||
|
|
||||||
|
Box::pin(async move {
|
||||||
|
let start_time = Instant::now();
|
||||||
|
let mut response = future.await?;
|
||||||
|
let latency = start_time.elapsed();
|
||||||
|
|
||||||
|
// Add request ID to response headers
|
||||||
|
response.headers_mut().insert(
|
||||||
|
"x-request-id",
|
||||||
|
HeaderValue::from_str(&request_id)
|
||||||
|
.unwrap_or_else(|_| HeaderValue::from_static("invalid-request-id")),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Log the response with proper request ID in span
|
||||||
|
let status = response.status();
|
||||||
|
let span = tracing::info_span!(
|
||||||
|
"http_request",
|
||||||
|
method = %method,
|
||||||
|
uri = %uri,
|
||||||
|
version = ?version,
|
||||||
|
request_id = %request_id,
|
||||||
|
status = %status,
|
||||||
|
latency = ?latency
|
||||||
|
);
|
||||||
|
|
||||||
|
let _enter = span.enter();
|
||||||
|
if status.is_server_error() {
|
||||||
|
tracing::error!(
|
||||||
|
target: "sglang_router_rs::response",
|
||||||
|
"request failed with server error"
|
||||||
|
);
|
||||||
|
} else if status.is_client_error() {
|
||||||
|
tracing::warn!(
|
||||||
|
target: "sglang_router_rs::response",
|
||||||
|
"request failed with client error"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
tracing::info!(
|
||||||
|
target: "sglang_router_rs::response",
|
||||||
|
"finished processing request"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Logging Middleware =============
|
||||||
|
|
||||||
|
/// Custom span maker that includes request ID
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct RequestSpan;
|
||||||
|
|
||||||
|
impl<B> MakeSpan<B> for RequestSpan {
|
||||||
|
fn make_span(&mut self, request: &Request<B>) -> Span {
|
||||||
|
// Don't try to extract request ID here - it won't be available yet
|
||||||
|
// The RequestIdLayer runs after TraceLayer creates the span
|
||||||
|
info_span!(
|
||||||
|
"http_request",
|
||||||
|
method = %request.method(),
|
||||||
|
uri = %request.uri(),
|
||||||
|
version = ?request.version(),
|
||||||
|
request_id = Empty, // Will be set later
|
||||||
|
status_code = Empty,
|
||||||
|
latency = Empty,
|
||||||
|
error = Empty,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Custom on_request handler
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct RequestLogger;
|
||||||
|
|
||||||
|
impl<B> OnRequest<B> for RequestLogger {
|
||||||
|
fn on_request(&mut self, request: &Request<B>, span: &Span) {
|
||||||
|
let _enter = span.enter();
|
||||||
|
|
||||||
|
// Try to get the request ID from extensions
|
||||||
|
// This will work if RequestIdLayer has already run
|
||||||
|
if let Some(request_id) = request.extensions().get::<RequestId>() {
|
||||||
|
span.record("request_id", &request_id.0.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't log here - we already log in RequestIdService with the proper request_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Custom on_response handler
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ResponseLogger {
|
||||||
|
_start_time: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ResponseLogger {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
_start_time: Instant::now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B> OnResponse<B> for ResponseLogger {
|
||||||
|
fn on_response(self, response: &Response<B>, latency: std::time::Duration, span: &Span) {
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
// Record these in the span for structured logging/observability tools
|
||||||
|
span.record("status_code", status.as_u16());
|
||||||
|
span.record("latency", format!("{:?}", latency));
|
||||||
|
|
||||||
|
// Don't log here - RequestIdService handles all logging with proper request IDs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a configured TraceLayer for HTTP logging
|
||||||
|
/// Note: Actual request/response logging with request IDs is done in RequestIdService
|
||||||
|
pub fn create_logging_layer() -> TraceLayer<
|
||||||
|
tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>,
|
||||||
|
RequestSpan,
|
||||||
|
RequestLogger,
|
||||||
|
ResponseLogger,
|
||||||
|
> {
|
||||||
|
TraceLayer::new_for_http()
|
||||||
|
.make_span_with(RequestSpan)
|
||||||
|
.on_request(RequestLogger)
|
||||||
|
.on_response(ResponseLogger::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Structured logging data for requests
|
||||||
|
#[derive(Debug, serde::Serialize)]
|
||||||
|
pub struct RequestLogEntry {
|
||||||
|
pub timestamp: String,
|
||||||
|
pub request_id: String,
|
||||||
|
pub method: String,
|
||||||
|
pub uri: String,
|
||||||
|
pub status: u16,
|
||||||
|
pub latency_ms: u64,
|
||||||
|
pub user_agent: Option<String>,
|
||||||
|
pub remote_addr: Option<String>,
|
||||||
|
pub error: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Log a request with structured data
|
||||||
|
pub fn log_request(entry: RequestLogEntry) {
|
||||||
|
if entry.status >= 500 {
|
||||||
|
tracing::error!(
|
||||||
|
target: "sglang_router_rs::http",
|
||||||
|
request_id = %entry.request_id,
|
||||||
|
method = %entry.method,
|
||||||
|
uri = %entry.uri,
|
||||||
|
status = entry.status,
|
||||||
|
latency_ms = entry.latency_ms,
|
||||||
|
user_agent = ?entry.user_agent,
|
||||||
|
remote_addr = ?entry.remote_addr,
|
||||||
|
error = ?entry.error,
|
||||||
|
"HTTP request failed"
|
||||||
|
);
|
||||||
|
} else if entry.status >= 400 {
|
||||||
|
tracing::warn!(
|
||||||
|
target: "sglang_router_rs::http",
|
||||||
|
request_id = %entry.request_id,
|
||||||
|
method = %entry.method,
|
||||||
|
uri = %entry.uri,
|
||||||
|
status = entry.status,
|
||||||
|
latency_ms = entry.latency_ms,
|
||||||
|
user_agent = ?entry.user_agent,
|
||||||
|
remote_addr = ?entry.remote_addr,
|
||||||
|
"HTTP request client error"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
tracing::info!(
|
||||||
|
target: "sglang_router_rs::http",
|
||||||
|
request_id = %entry.request_id,
|
||||||
|
method = %entry.method,
|
||||||
|
uri = %entry.uri,
|
||||||
|
status = entry.status,
|
||||||
|
latency_ms = entry.latency_ms,
|
||||||
|
user_agent = ?entry.user_agent,
|
||||||
|
remote_addr = ?entry.remote_addr,
|
||||||
|
"HTTP request completed"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,17 @@
|
|||||||
//! Router implementations
|
//! Router implementations
|
||||||
|
|
||||||
use actix_web::{HttpRequest, HttpResponse};
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use axum::{
|
||||||
|
body::Body,
|
||||||
|
extract::Request,
|
||||||
|
http::{HeaderMap, StatusCode},
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
|
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod pd_router;
|
pub mod pd_router;
|
||||||
pub mod pd_types;
|
pub mod pd_types;
|
||||||
@@ -33,54 +40,55 @@ pub trait WorkerManagement: Send + Sync {
|
|||||||
///
|
///
|
||||||
/// This trait provides a unified interface for routing requests,
|
/// This trait provides a unified interface for routing requests,
|
||||||
/// regardless of whether it's a regular router or PD router.
|
/// regardless of whether it's a regular router or PD router.
|
||||||
#[async_trait(?Send)]
|
#[async_trait]
|
||||||
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||||
/// Get a reference to self as Any for downcasting
|
/// Get a reference to self as Any for downcasting
|
||||||
fn as_any(&self) -> &dyn std::any::Any;
|
fn as_any(&self) -> &dyn std::any::Any;
|
||||||
|
|
||||||
/// Route a health check request
|
/// Route a health check request
|
||||||
async fn health(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
async fn health(&self, client: &Client, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Route a health generate request
|
/// Route a health generate request
|
||||||
async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Get server information
|
/// Get server information
|
||||||
async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Get available models
|
/// Get available models
|
||||||
async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Get model information
|
/// Get model information
|
||||||
async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Route a generate request
|
/// Route a generate request
|
||||||
async fn route_generate(
|
async fn route_generate(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
client: &Client,
|
||||||
req: &HttpRequest,
|
headers: Option<&HeaderMap>,
|
||||||
body: serde_json::Value,
|
body: &GenerateRequest,
|
||||||
) -> HttpResponse;
|
) -> Response;
|
||||||
|
|
||||||
/// Route a chat completion request
|
/// Route a chat completion request
|
||||||
async fn route_chat(
|
async fn route_chat(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
client: &Client,
|
||||||
req: &HttpRequest,
|
headers: Option<&HeaderMap>,
|
||||||
body: serde_json::Value,
|
body: &ChatCompletionRequest,
|
||||||
) -> HttpResponse;
|
) -> Response;
|
||||||
|
|
||||||
/// Route a completion request
|
/// Route a completion request
|
||||||
async fn route_completion(
|
async fn route_completion(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
client: &Client,
|
||||||
req: &HttpRequest,
|
headers: Option<&HeaderMap>,
|
||||||
body: serde_json::Value,
|
body: &CompletionRequest,
|
||||||
) -> HttpResponse;
|
) -> Response;
|
||||||
|
|
||||||
/// Flush cache on all workers
|
/// Flush cache on all workers
|
||||||
async fn flush_cache(&self, client: &Client) -> HttpResponse;
|
async fn flush_cache(&self, client: &Client) -> Response;
|
||||||
|
|
||||||
/// Get worker loads (for monitoring)
|
/// Get worker loads (for monitoring)
|
||||||
async fn get_worker_loads(&self, client: &Client) -> HttpResponse;
|
async fn get_worker_loads(&self, client: &Client) -> Response;
|
||||||
|
|
||||||
/// Get router type name
|
/// Get router type name
|
||||||
fn router_type(&self) -> &'static str;
|
fn router_type(&self) -> &'static str;
|
||||||
@@ -91,11 +99,11 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Server liveness check - is the server process running
|
/// Server liveness check - is the server process running
|
||||||
fn liveness(&self) -> HttpResponse {
|
fn liveness(&self) -> Response {
|
||||||
// Simple liveness check - if we can respond, we're alive
|
// Simple liveness check - if we can respond, we're alive
|
||||||
HttpResponse::Ok().body("OK")
|
(StatusCode::OK, "OK").into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Server readiness check - is the server ready to handle requests
|
/// Server readiness check - is the server ready to handle requests
|
||||||
fn readiness(&self) -> HttpResponse;
|
fn readiness(&self) -> Response;
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,17 +1,23 @@
|
|||||||
use crate::core::{HealthChecker, Worker, WorkerFactory};
|
use crate::core::{HealthChecker, Worker, WorkerFactory};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::middleware::get_request_id;
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
use actix_web::{HttpRequest, HttpResponse};
|
use axum::{
|
||||||
use futures_util::{StreamExt, TryStreamExt};
|
body::Body,
|
||||||
|
extract::Request,
|
||||||
|
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use futures_util::StreamExt;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use std::thread;
|
use std::thread;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
|
||||||
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
|
||||||
req.headers()
|
req.headers()
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|(name, value)| {
|
.filter_map(|(name, value)| {
|
||||||
@@ -239,154 +245,107 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_request(
|
pub async fn send_health_check(&self, client: &Client, worker_url: &str) -> Response {
|
||||||
&self,
|
let health_url = if self.dp_aware {
|
||||||
client: &reqwest::Client,
|
|
||||||
worker_url: &str,
|
|
||||||
route: &str,
|
|
||||||
req: &HttpRequest,
|
|
||||||
) -> HttpResponse {
|
|
||||||
let request_id = get_request_id(req);
|
|
||||||
let start = Instant::now();
|
|
||||||
|
|
||||||
let worker_url = if self.dp_aware {
|
|
||||||
// Need to extract the URL from "http://host:port@dp_rank"
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
match Self::extract_dp_rank(worker_url) {
|
||||||
Ok(tup) => tup,
|
Ok((worker_url_prefix, _dp_rank)) => worker_url_prefix,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to extract dp_rank: {}", e);
|
error!("Failed to extract dp_rank for health check: {}", e);
|
||||||
return HttpResponse::InternalServerError().finish();
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to extract dp_rank: {}", e),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
worker_url_prefix
|
|
||||||
} else {
|
} else {
|
||||||
worker_url
|
worker_url
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut request_builder = client.get(format!("{}{}", worker_url, route));
|
let request_builder = client.get(format!("{}/health", health_url));
|
||||||
|
|
||||||
// Copy all headers from original request except for /health because it does not need authorization
|
|
||||||
if route != "/health" {
|
|
||||||
for (name, value) in copy_request_headers(req) {
|
|
||||||
// Skip Content-Type and Content-Length as .json() sets them
|
|
||||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
|
||||||
{
|
|
||||||
request_builder = request_builder.header(name, value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = match request_builder.send().await {
|
let response = match request_builder.send().await {
|
||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
let status = StatusCode::from_u16(res.status().as_u16())
|
||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
match res.bytes().await {
|
match res.bytes().await {
|
||||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
Ok(body) => (status, body).into_response(),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(
|
error!(
|
||||||
request_id = %request_id,
|
worker_url = %health_url,
|
||||||
worker_url = %worker_url,
|
|
||||||
route = %route,
|
|
||||||
error = %e,
|
error = %e,
|
||||||
"Failed to read response body"
|
"Failed to read health response body"
|
||||||
);
|
);
|
||||||
HttpResponse::InternalServerError()
|
(
|
||||||
.body(format!("Failed to read response body: {}", e))
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to read response body: {}", e),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(
|
error!(
|
||||||
request_id = %request_id,
|
worker_url = %health_url,
|
||||||
worker_url = %worker_url,
|
|
||||||
route = %route,
|
|
||||||
error = %e,
|
error = %e,
|
||||||
"Failed to send request to worker"
|
"Failed to send health request to worker"
|
||||||
);
|
);
|
||||||
HttpResponse::InternalServerError().body(format!(
|
(
|
||||||
"Failed to send request to worker {}: {}",
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
worker_url, e
|
format!("Failed to send request to worker {}: {}", health_url, e),
|
||||||
))
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Record request metrics
|
// Don't record metrics for health checks
|
||||||
if route != "/health" {
|
|
||||||
let duration = start.elapsed();
|
|
||||||
RouterMetrics::record_request(route);
|
|
||||||
RouterMetrics::record_request_duration(route, duration);
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
|
||||||
RouterMetrics::record_request_error(route, "request_failed");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn route_to_first(
|
// Helper method to proxy GET requests to the first available worker
|
||||||
|
async fn proxy_get_request(
|
||||||
&self,
|
&self,
|
||||||
client: &reqwest::Client,
|
client: &Client,
|
||||||
route: &str,
|
req: Request<Body>,
|
||||||
req: &HttpRequest,
|
endpoint: &str,
|
||||||
) -> HttpResponse {
|
) -> Response {
|
||||||
let request_id = get_request_id(req);
|
let headers = copy_request_headers(&req);
|
||||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
|
||||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
|
||||||
let mut total_retries = 0;
|
|
||||||
|
|
||||||
while total_retries < MAX_TOTAL_RETRIES {
|
match self.select_first_worker() {
|
||||||
match self.select_first_worker() {
|
Ok(worker_url) => {
|
||||||
Ok(worker_url) => {
|
let mut request_builder = client.get(format!("{}/{}", worker_url, endpoint));
|
||||||
let mut request_retries = 0;
|
for (name, value) in headers {
|
||||||
|
if name.to_lowercase() != "content-type"
|
||||||
// Try the same worker multiple times
|
&& name.to_lowercase() != "content-length"
|
||||||
while request_retries < MAX_REQUEST_RETRIES {
|
{
|
||||||
if total_retries >= 1 {
|
request_builder = request_builder.header(name, value);
|
||||||
info!("Retrying request after {} failed attempts", total_retries);
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = self.send_request(client, &worker_url, route, req).await;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
|
||||||
return response;
|
|
||||||
} else {
|
|
||||||
// if the worker is healthy, it means the request is bad, so return the error response
|
|
||||||
let health_response =
|
|
||||||
self.send_request(client, &worker_url, "/health", req).await;
|
|
||||||
if health_response.status().is_success() {
|
|
||||||
return response;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
warn!(
|
|
||||||
request_id = %request_id,
|
|
||||||
route = %route,
|
|
||||||
worker_url = %worker_url,
|
|
||||||
attempt = request_retries + 1,
|
|
||||||
max_attempts = MAX_REQUEST_RETRIES,
|
|
||||||
"Request failed"
|
|
||||||
);
|
|
||||||
|
|
||||||
request_retries += 1;
|
|
||||||
total_retries += 1;
|
|
||||||
|
|
||||||
if request_retries == MAX_REQUEST_RETRIES {
|
|
||||||
warn!(
|
|
||||||
request_id = %request_id,
|
|
||||||
worker_url = %worker_url,
|
|
||||||
"Removing failed worker"
|
|
||||||
);
|
|
||||||
self.remove_failed_worker(&worker_url);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => return HttpResponse::InternalServerError().body(e),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
match request_builder.send().await {
|
||||||
|
Ok(res) => {
|
||||||
|
let status = StatusCode::from_u16(res.status().as_u16())
|
||||||
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
match res.bytes().await {
|
||||||
|
Ok(body) => (status, body).into_response(),
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to read response: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Request failed: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => (StatusCode::SERVICE_UNAVAILABLE, e).into_response(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// New method to route typed requests directly
|
// New method to route typed requests directly
|
||||||
@@ -395,11 +354,10 @@ impl Router {
|
|||||||
>(
|
>(
|
||||||
&self,
|
&self,
|
||||||
client: &reqwest::Client,
|
client: &reqwest::Client,
|
||||||
req: &HttpRequest,
|
headers: Option<&HeaderMap>,
|
||||||
typed_req: &T,
|
typed_req: &T,
|
||||||
route: &str,
|
route: &str,
|
||||||
) -> HttpResponse {
|
) -> Response {
|
||||||
let request_id = get_request_id(req);
|
|
||||||
// Handle retries like the original implementation
|
// Handle retries like the original implementation
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||||
@@ -440,7 +398,7 @@ impl Router {
|
|||||||
let response = self
|
let response = self
|
||||||
.send_typed_request(
|
.send_typed_request(
|
||||||
client,
|
client,
|
||||||
req,
|
headers,
|
||||||
typed_req,
|
typed_req,
|
||||||
route,
|
route,
|
||||||
&worker_url,
|
&worker_url,
|
||||||
@@ -455,8 +413,7 @@ impl Router {
|
|||||||
return response;
|
return response;
|
||||||
} else {
|
} else {
|
||||||
// if the worker is healthy, it means the request is bad, so return the error response
|
// if the worker is healthy, it means the request is bad, so return the error response
|
||||||
let health_response =
|
let health_response = self.send_health_check(client, &worker_url).await;
|
||||||
self.send_request(client, &worker_url, "/health", req).await;
|
|
||||||
if health_response.status().is_success() {
|
if health_response.status().is_success() {
|
||||||
RouterMetrics::record_request_error(route, "request_failed");
|
RouterMetrics::record_request_error(route, "request_failed");
|
||||||
return response;
|
return response;
|
||||||
@@ -464,9 +421,11 @@ impl Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
warn!(
|
warn!(
|
||||||
request_id = %request_id,
|
|
||||||
"Generate request failed route={} worker_url={} attempt={} max_attempts={}",
|
"Generate request failed route={} worker_url={} attempt={} max_attempts={}",
|
||||||
route, worker_url, request_retries + 1, MAX_REQUEST_RETRIES
|
route,
|
||||||
|
worker_url,
|
||||||
|
request_retries + 1,
|
||||||
|
MAX_REQUEST_RETRIES
|
||||||
);
|
);
|
||||||
|
|
||||||
request_retries += 1;
|
request_retries += 1;
|
||||||
@@ -474,17 +433,21 @@ impl Router {
|
|||||||
|
|
||||||
if request_retries == MAX_REQUEST_RETRIES {
|
if request_retries == MAX_REQUEST_RETRIES {
|
||||||
warn!(
|
warn!(
|
||||||
request_id = %request_id,
|
"Removing failed worker after typed request failures worker_url={}",
|
||||||
"Removing failed worker after typed request failures worker_url={}", worker_url
|
worker_url
|
||||||
);
|
);
|
||||||
self.remove_failed_worker(&worker_url);
|
self.remove_worker(&worker_url);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RouterMetrics::record_request_error(route, "request_failed");
|
RouterMetrics::record_request_error(route, "request_failed");
|
||||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"All retry attempts failed",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper method to select worker from text using the policy
|
// Helper method to select worker from text using the policy
|
||||||
@@ -521,14 +484,13 @@ impl Router {
|
|||||||
async fn send_typed_request<T: serde::Serialize>(
|
async fn send_typed_request<T: serde::Serialize>(
|
||||||
&self,
|
&self,
|
||||||
client: &reqwest::Client,
|
client: &reqwest::Client,
|
||||||
req: &HttpRequest,
|
headers: Option<&HeaderMap>,
|
||||||
typed_req: &T,
|
typed_req: &T,
|
||||||
route: &str,
|
route: &str,
|
||||||
worker_url: &str,
|
worker_url: &str,
|
||||||
is_stream: bool,
|
is_stream: bool,
|
||||||
load_incremented: bool, // Whether load was incremented for this request
|
load_incremented: bool, // Whether load was incremented for this request
|
||||||
) -> HttpResponse {
|
) -> Response {
|
||||||
let request_id = get_request_id(req);
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let mut request_builder = if self.dp_aware {
|
let mut request_builder = if self.dp_aware {
|
||||||
@@ -536,7 +498,11 @@ impl Router {
|
|||||||
Ok(tup) => tup,
|
Ok(tup) => tup,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to extract dp_rank: {}", e);
|
error!("Failed to extract dp_rank: {}", e);
|
||||||
return HttpResponse::InternalServerError().finish();
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to extract dp_rank: {}", e),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -544,8 +510,11 @@ impl Router {
|
|||||||
let mut json_val = match serde_json::to_value(typed_req) {
|
let mut json_val = match serde_json::to_value(typed_req) {
|
||||||
Ok(j) => j,
|
Ok(j) => j,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return HttpResponse::BadRequest()
|
return (
|
||||||
.body(format!("Convert into serde_json::Value failed: {}", e));
|
StatusCode::BAD_REQUEST,
|
||||||
|
format!("Convert into serde_json::Value failed: {}", e),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -560,8 +529,11 @@ impl Router {
|
|||||||
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
|
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
return HttpResponse::BadRequest()
|
return (
|
||||||
.body("Failed to insert the data_parallel_rank field into the request body");
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Failed to insert the data_parallel_rank field into the request body",
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
}
|
}
|
||||||
|
|
||||||
client
|
client
|
||||||
@@ -573,11 +545,15 @@ impl Router {
|
|||||||
.json(typed_req) // Use json() directly with typed request
|
.json(typed_req) // Use json() directly with typed request
|
||||||
};
|
};
|
||||||
|
|
||||||
// Copy all headers from original request
|
// Copy all headers from original request if provided
|
||||||
for (name, value) in copy_request_headers(req) {
|
if let Some(headers) = headers {
|
||||||
// Skip Content-Type and Content-Length as .json() sets them
|
for (name, value) in headers {
|
||||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
|
// Skip Content-Type and Content-Length as .json() sets them
|
||||||
request_builder = request_builder.header(&name, &value);
|
if name.to_string().to_lowercase() != "content-type"
|
||||||
|
&& name.to_string().to_lowercase() != "content-length"
|
||||||
|
{
|
||||||
|
request_builder = request_builder.header(name, value);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -585,7 +561,6 @@ impl Router {
|
|||||||
Ok(res) => res,
|
Ok(res) => res,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(
|
error!(
|
||||||
request_id = %request_id,
|
|
||||||
"Failed to send typed request worker_url={} route={} error={}",
|
"Failed to send typed request worker_url={} route={} error={}",
|
||||||
worker_url, route, e
|
worker_url, route, e
|
||||||
);
|
);
|
||||||
@@ -600,20 +575,24 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Request failed: {}", e),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
let status = StatusCode::from_u16(res.status().as_u16())
|
||||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
|
||||||
if !is_stream {
|
if !is_stream {
|
||||||
// For non-streaming requests, get response first
|
// For non-streaming requests, get response first
|
||||||
let response = match res.bytes().await {
|
let response = match res.bytes().await {
|
||||||
Ok(body) => HttpResponse::build(status).body(body.to_vec()),
|
Ok(body) => (status, body).into_response(),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let error_msg = format!("Failed to get response body: {}", e);
|
let error_msg = format!("Failed to get response body: {}", e);
|
||||||
HttpResponse::InternalServerError().body(error_msg)
|
(StatusCode::INTERNAL_SERVER_ERROR, error_msg).into_response()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -638,42 +617,86 @@ impl Router {
|
|||||||
let workers = Arc::clone(&self.workers);
|
let workers = Arc::clone(&self.workers);
|
||||||
let worker_url = worker_url.to_string();
|
let worker_url = worker_url.to_string();
|
||||||
|
|
||||||
HttpResponse::build(status)
|
let stream = res.bytes_stream();
|
||||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
.streaming(
|
|
||||||
res.bytes_stream()
|
// Spawn task to forward stream and detect completion
|
||||||
.map_err(|_| {
|
tokio::spawn(async move {
|
||||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
let mut stream = stream;
|
||||||
})
|
while let Some(chunk) = stream.next().await {
|
||||||
.inspect(move |bytes| {
|
match chunk {
|
||||||
if let Ok(bytes) = bytes {
|
Ok(bytes) => {
|
||||||
if bytes
|
// Check for stream end marker
|
||||||
.as_ref()
|
if bytes
|
||||||
.windows(12)
|
.as_ref()
|
||||||
.any(|window| window == b"data: [DONE]")
|
.windows(12)
|
||||||
{
|
.any(|window| window == b"data: [DONE]")
|
||||||
if let Ok(workers_guard) = workers.read() {
|
{
|
||||||
if let Some(worker) =
|
if let Ok(workers_guard) = workers.read() {
|
||||||
workers_guard.iter().find(|w| w.url() == &worker_url)
|
if let Some(worker) =
|
||||||
{
|
workers_guard.iter().find(|w| w.url() == &worker_url)
|
||||||
worker.decrement_load();
|
{
|
||||||
RouterMetrics::set_running_requests(
|
worker.decrement_load();
|
||||||
&worker_url,
|
RouterMetrics::set_running_requests(
|
||||||
worker.load(),
|
&worker_url,
|
||||||
);
|
worker.load(),
|
||||||
}
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}),
|
if tx.send(Ok(bytes)).is_err() {
|
||||||
)
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(format!("Stream error: {}", e)));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let stream = UnboundedReceiverStream::new(rx);
|
||||||
|
let body = Body::from_stream(stream);
|
||||||
|
|
||||||
|
let mut response = Response::new(body);
|
||||||
|
*response.status_mut() = status;
|
||||||
|
response
|
||||||
|
.headers_mut()
|
||||||
|
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||||
|
response
|
||||||
} else {
|
} else {
|
||||||
// For requests without load tracking, just stream
|
// For requests without load tracking, just stream
|
||||||
HttpResponse::build(status)
|
let stream = res.bytes_stream();
|
||||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
.streaming(res.bytes_stream().map_err(|_| {
|
|
||||||
actix_web::error::ErrorInternalServerError("Failed to read stream")
|
// Spawn task to forward stream
|
||||||
}))
|
tokio::spawn(async move {
|
||||||
|
let mut stream = stream;
|
||||||
|
while let Some(chunk) = stream.next().await {
|
||||||
|
match chunk {
|
||||||
|
Ok(bytes) => {
|
||||||
|
if tx.send(Ok(bytes)).is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(format!("Stream error: {}", e)));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let stream = UnboundedReceiverStream::new(rx);
|
||||||
|
let body = Body::from_stream(stream);
|
||||||
|
|
||||||
|
let mut response = Response::new(body);
|
||||||
|
*response.status_mut() = status;
|
||||||
|
response
|
||||||
|
.headers_mut()
|
||||||
|
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
|
||||||
|
response
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -775,7 +798,6 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Remove all the worker(s) that match the URL prefix
|
|
||||||
pub fn remove_worker(&self, worker_url: &str) {
|
pub fn remove_worker(&self, worker_url: &str) {
|
||||||
if self.dp_aware {
|
if self.dp_aware {
|
||||||
// remove dp-aware workers in a prefix-matching fashion
|
// remove dp-aware workers in a prefix-matching fashion
|
||||||
@@ -844,28 +866,6 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Remove a specific failed worker; for internal usage
|
|
||||||
fn remove_failed_worker(&self, worker_url: &str) {
|
|
||||||
let mut workers_guard = self.workers.write().unwrap();
|
|
||||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
|
||||||
workers_guard.remove(index);
|
|
||||||
info!("Removed failed worker: {}", worker_url);
|
|
||||||
RouterMetrics::set_active_workers(workers_guard.len());
|
|
||||||
} else {
|
|
||||||
warn!("Worker {} not found, skipping removal", worker_url);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If cache aware policy, remove the worker from the tree
|
|
||||||
if let Some(cache_aware) = self
|
|
||||||
.policy
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
|
||||||
{
|
|
||||||
cache_aware.remove_worker(worker_url);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||||
let worker_url = if self.dp_aware {
|
let worker_url = if self.dp_aware {
|
||||||
// Need to extract the URL from "http://host:port@dp_rank"
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
@@ -1004,7 +1004,6 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
|
|
||||||
@@ -1023,100 +1022,78 @@ impl WorkerManagement for Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait(?Send)]
|
#[async_trait]
|
||||||
impl RouterTrait for Router {
|
impl RouterTrait for Router {
|
||||||
fn as_any(&self) -> &dyn std::any::Any {
|
fn as_any(&self) -> &dyn std::any::Any {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, _client: &Client, _req: &HttpRequest) -> HttpResponse {
|
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response {
|
||||||
// Check local health state of all workers (consistent with PD router)
|
let workers = self.workers.read().unwrap();
|
||||||
// Note: This uses cached health status from background health checks, not live checks
|
let unhealthy_servers: Vec<_> = workers
|
||||||
let mut all_healthy = true;
|
.iter()
|
||||||
let mut unhealthy_servers = Vec::new();
|
.filter(|w| !w.is_healthy())
|
||||||
|
.map(|w| w.url().to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
for worker in self.workers.read().unwrap().iter() {
|
if unhealthy_servers.is_empty() {
|
||||||
if !worker.is_healthy() {
|
(StatusCode::OK, "All servers healthy").into_response()
|
||||||
all_healthy = false;
|
|
||||||
unhealthy_servers.push(worker.url().to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if all_healthy {
|
|
||||||
HttpResponse::Ok().body("All servers healthy")
|
|
||||||
} else {
|
} else {
|
||||||
HttpResponse::ServiceUnavailable()
|
(
|
||||||
.body(format!("Unhealthy servers: {:?}", unhealthy_servers))
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
format!("Unhealthy servers: {:?}", unhealthy_servers),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response {
|
||||||
// Test model generation capability by sending to first available worker
|
self.proxy_get_request(client, req, "health_generate").await
|
||||||
// Note: This endpoint actually causes the model to generate a token, so we only test one worker
|
|
||||||
self.route_to_first(client, "/health_generate", req).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response {
|
||||||
self.route_to_first(client, "/get_server_info", req).await
|
self.proxy_get_request(client, req, "get_server_info").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response {
|
||||||
self.route_to_first(client, "/v1/models", req).await
|
self.proxy_get_request(client, req, "v1/models").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse {
|
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response {
|
||||||
self.route_to_first(client, "/get_model_info", req).await
|
self.proxy_get_request(client, req, "get_model_info").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_generate(
|
async fn route_generate(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
client: &Client,
|
||||||
req: &HttpRequest,
|
headers: Option<&HeaderMap>,
|
||||||
body: serde_json::Value,
|
body: &GenerateRequest,
|
||||||
) -> HttpResponse {
|
) -> Response {
|
||||||
// Convert JSON to typed request
|
self.route_typed_request(client, headers, body, "/generate")
|
||||||
match serde_json::from_value::<crate::openai_api_types::GenerateRequest>(body) {
|
.await
|
||||||
Ok(typed_req) => {
|
|
||||||
self.route_typed_request(client, req, &typed_req, "/generate")
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_chat(
|
async fn route_chat(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
client: &Client,
|
||||||
req: &HttpRequest,
|
headers: Option<&HeaderMap>,
|
||||||
body: serde_json::Value,
|
body: &ChatCompletionRequest,
|
||||||
) -> HttpResponse {
|
) -> Response {
|
||||||
// Convert JSON to typed request
|
self.route_typed_request(client, headers, body, "/v1/chat/completions")
|
||||||
match serde_json::from_value::<crate::openai_api_types::ChatCompletionRequest>(body) {
|
.await
|
||||||
Ok(typed_req) => {
|
|
||||||
self.route_typed_request(client, req, &typed_req, "/v1/chat/completions")
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_completion(
|
async fn route_completion(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
client: &Client,
|
||||||
req: &HttpRequest,
|
headers: Option<&HeaderMap>,
|
||||||
body: serde_json::Value,
|
body: &CompletionRequest,
|
||||||
) -> HttpResponse {
|
) -> Response {
|
||||||
// Convert JSON to typed request
|
self.route_typed_request(client, headers, body, "/v1/completions")
|
||||||
match serde_json::from_value::<crate::openai_api_types::CompletionRequest>(body) {
|
.await
|
||||||
Ok(typed_req) => {
|
|
||||||
self.route_typed_request(client, req, &typed_req, "/v1/completions")
|
|
||||||
.await
|
|
||||||
}
|
|
||||||
Err(e) => HttpResponse::BadRequest().body(format!("Invalid request: {}", e)),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self, client: &Client) -> HttpResponse {
|
async fn flush_cache(&self, client: &Client) -> Response {
|
||||||
// Get all worker URLs
|
// Get all worker URLs
|
||||||
let worker_urls = self.get_worker_urls();
|
let worker_urls = self.get_worker_urls();
|
||||||
|
|
||||||
@@ -1129,7 +1106,11 @@ impl RouterTrait for Router {
|
|||||||
Ok(tup) => tup,
|
Ok(tup) => tup,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to extract dp_rank: {}", e);
|
error!("Failed to extract dp_rank: {}", e);
|
||||||
return HttpResponse::InternalServerError().finish();
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to extract dp_rank: {}", e),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
worker_url_prefix
|
worker_url_prefix
|
||||||
@@ -1151,13 +1132,17 @@ impl RouterTrait for Router {
|
|||||||
});
|
});
|
||||||
|
|
||||||
if all_success {
|
if all_success {
|
||||||
HttpResponse::Ok().body("Cache flushed on all servers")
|
(StatusCode::OK, "Cache flushed on all servers").into_response()
|
||||||
} else {
|
} else {
|
||||||
HttpResponse::InternalServerError().body("Cache flush failed on one or more servers")
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
"Cache flush failed on one or more servers",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_worker_loads(&self, client: &Client) -> HttpResponse {
|
async fn get_worker_loads(&self, client: &Client) -> Response {
|
||||||
let urls = self.get_worker_urls();
|
let urls = self.get_worker_urls();
|
||||||
let mut loads = Vec::new();
|
let mut loads = Vec::new();
|
||||||
|
|
||||||
@@ -1170,16 +1155,17 @@ impl RouterTrait for Router {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
HttpResponse::Ok().json(serde_json::json!({
|
Json(serde_json::json!({
|
||||||
"workers": loads
|
"workers": loads
|
||||||
}))
|
}))
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn router_type(&self) -> &'static str {
|
fn router_type(&self) -> &'static str {
|
||||||
"regular"
|
"regular"
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readiness(&self) -> HttpResponse {
|
fn readiness(&self) -> Response {
|
||||||
// Regular router is ready if it has at least one healthy worker
|
// Regular router is ready if it has at least one healthy worker
|
||||||
let healthy_count = self
|
let healthy_count = self
|
||||||
.workers
|
.workers
|
||||||
@@ -1190,17 +1176,22 @@ impl RouterTrait for Router {
|
|||||||
.count();
|
.count();
|
||||||
|
|
||||||
if healthy_count > 0 {
|
if healthy_count > 0 {
|
||||||
HttpResponse::Ok().json(serde_json::json!({
|
Json(serde_json::json!({
|
||||||
"status": "ready",
|
"status": "ready",
|
||||||
"healthy_workers": healthy_count,
|
"healthy_workers": healthy_count,
|
||||||
"total_workers": self.workers.read().unwrap().len()
|
"total_workers": self.workers.read().unwrap().len()
|
||||||
}))
|
}))
|
||||||
|
.into_response()
|
||||||
} else {
|
} else {
|
||||||
HttpResponse::ServiceUnavailable().json(serde_json::json!({
|
(
|
||||||
"status": "not_ready",
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
"reason": "no healthy workers available",
|
Json(serde_json::json!({
|
||||||
"total_workers": self.workers.read().unwrap().len()
|
"status": "not_ready",
|
||||||
}))
|
"reason": "no healthy workers available",
|
||||||
|
"total_workers": self.workers.read().unwrap().len()
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,285 +1,169 @@
|
|||||||
use crate::config::RouterConfig;
|
use crate::config::RouterConfig;
|
||||||
use crate::logging::{self, LoggingConfig};
|
use crate::logging::{self, LoggingConfig};
|
||||||
use crate::metrics::{self, PrometheusConfig};
|
use crate::metrics::{self, PrometheusConfig};
|
||||||
use crate::middleware::{get_request_id, RequestIdMiddleware};
|
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::routers::{RouterFactory, RouterTrait};
|
use crate::routers::{RouterFactory, RouterTrait};
|
||||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||||
use actix_web::{
|
use axum::{
|
||||||
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
|
extract::{Query, Request, State},
|
||||||
|
http::StatusCode,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
routing::{get, post},
|
||||||
|
Json, Router,
|
||||||
};
|
};
|
||||||
use futures_util::StreamExt;
|
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::signal;
|
||||||
use tokio::spawn;
|
use tokio::spawn;
|
||||||
use tracing::{error, info, warn, Level};
|
use tracing::{error, info, warn, Level};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
router: Arc<dyn RouterTrait>,
|
pub router: Arc<dyn RouterTrait>,
|
||||||
client: Client,
|
pub client: Client,
|
||||||
|
pub _concurrency_limiter: Arc<tokio::sync::Semaphore>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
pub fn new(router_config: RouterConfig, client: Client) -> Result<Self, String> {
|
pub fn new(
|
||||||
// Use RouterFactory to create the appropriate router type
|
router_config: RouterConfig,
|
||||||
|
client: Client,
|
||||||
|
max_concurrent_requests: usize,
|
||||||
|
) -> Result<Self, String> {
|
||||||
let router = RouterFactory::create_router(&router_config)?;
|
let router = RouterFactory::create_router(&router_config)?;
|
||||||
|
|
||||||
// Convert Box<dyn RouterTrait> to Arc<dyn RouterTrait>
|
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
|
||||||
Ok(Self { router, client })
|
Ok(Self {
|
||||||
|
router,
|
||||||
|
client,
|
||||||
|
_concurrency_limiter: concurrency_limiter,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<HttpResponse, Error> {
|
// Fallback handler for unmatched routes
|
||||||
// Drain the payload
|
async fn sink_handler() -> Response {
|
||||||
while let Some(chunk) = payload.next().await {
|
StatusCode::NOT_FOUND.into_response()
|
||||||
if let Err(err) = chunk {
|
|
||||||
println!("Error while draining payload: {:?}", err);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(HttpResponse::NotFound().finish())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Custom error handler for JSON payload errors.
|
// Health check endpoints
|
||||||
fn json_error_handler(err: error::JsonPayloadError, req: &HttpRequest) -> Error {
|
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
|
||||||
let request_id = get_request_id(req);
|
state.router.liveness()
|
||||||
match &err {
|
|
||||||
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
|
|
||||||
error!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Payload too large length={} limit={}", length, limit
|
|
||||||
);
|
|
||||||
error::ErrorPayloadTooLarge(format!(
|
|
||||||
"Payload too large: {} bytes exceeds limit of {} bytes",
|
|
||||||
length, limit
|
|
||||||
))
|
|
||||||
}
|
|
||||||
error::JsonPayloadError::Overflow { limit } => {
|
|
||||||
error!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Payload overflow limit={}", limit
|
|
||||||
);
|
|
||||||
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
error!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Invalid JSON payload error={}", err
|
|
||||||
);
|
|
||||||
error::ErrorBadRequest(format!("Invalid JSON payload: {}", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/liveness")]
|
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
|
||||||
async fn liveness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
state.router.readiness()
|
||||||
data.router.liveness()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/readiness")]
|
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
async fn readiness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
state.router.health(&state.client, req).await
|
||||||
data.router.readiness()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/health")]
|
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
state.router.health_generate(&state.client, req).await
|
||||||
data.router.health(&data.client, &req).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/health_generate")]
|
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
state.router.get_server_info(&state.client, req).await
|
||||||
data.router.health_generate(&data.client, &req).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/get_server_info")]
|
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
state.router.get_models(&state.client, req).await
|
||||||
data.router.get_server_info(&data.client, &req).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/v1/models")]
|
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
state.router.get_model_info(&state.client, req).await
|
||||||
data.router.get_models(&data.client, &req).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/get_model_info")]
|
// Generation endpoints
|
||||||
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
// The RouterTrait now accepts optional headers and typed body directly
|
||||||
data.router.get_model_info(&data.client, &req).await
|
|
||||||
}
|
|
||||||
|
|
||||||
#[post("/generate")]
|
|
||||||
async fn generate(
|
async fn generate(
|
||||||
req: HttpRequest,
|
State(state): State<Arc<AppState>>,
|
||||||
body: web::Json<GenerateRequest>,
|
headers: http::HeaderMap,
|
||||||
state: web::Data<AppState>,
|
Json(body): Json<GenerateRequest>,
|
||||||
) -> Result<HttpResponse, Error> {
|
) -> Response {
|
||||||
let request_id = get_request_id(&req);
|
state
|
||||||
info!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Received generate request method=\"POST\" path=\"/generate\""
|
|
||||||
);
|
|
||||||
|
|
||||||
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
|
|
||||||
error!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Failed to parse generate request body error={}", e
|
|
||||||
);
|
|
||||||
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(state
|
|
||||||
.router
|
.router
|
||||||
.route_generate(&state.client, &req, json_body)
|
.route_generate(&state.client, Some(&headers), &body)
|
||||||
.await)
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/v1/chat/completions")]
|
|
||||||
async fn v1_chat_completions(
|
async fn v1_chat_completions(
|
||||||
req: HttpRequest,
|
State(state): State<Arc<AppState>>,
|
||||||
body: web::Json<ChatCompletionRequest>,
|
headers: http::HeaderMap,
|
||||||
state: web::Data<AppState>,
|
Json(body): Json<ChatCompletionRequest>,
|
||||||
) -> Result<HttpResponse, Error> {
|
) -> Response {
|
||||||
let request_id = get_request_id(&req);
|
state
|
||||||
info!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Received chat completion request method=\"POST\" path=\"/v1/chat/completions\""
|
|
||||||
);
|
|
||||||
|
|
||||||
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
|
|
||||||
error!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Failed to parse chat completion request body error={}", e
|
|
||||||
);
|
|
||||||
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(state
|
|
||||||
.router
|
.router
|
||||||
.route_chat(&state.client, &req, json_body)
|
.route_chat(&state.client, Some(&headers), &body)
|
||||||
.await)
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/v1/completions")]
|
|
||||||
async fn v1_completions(
|
async fn v1_completions(
|
||||||
req: HttpRequest,
|
State(state): State<Arc<AppState>>,
|
||||||
body: web::Json<CompletionRequest>,
|
headers: http::HeaderMap,
|
||||||
state: web::Data<AppState>,
|
Json(body): Json<CompletionRequest>,
|
||||||
) -> Result<HttpResponse, Error> {
|
) -> Response {
|
||||||
let request_id = get_request_id(&req);
|
state
|
||||||
info!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Received completion request method=\"POST\" path=\"/v1/completions\""
|
|
||||||
);
|
|
||||||
|
|
||||||
let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
|
|
||||||
error!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Failed to parse completion request body error={}", e
|
|
||||||
);
|
|
||||||
error::ErrorBadRequest(format!("Invalid JSON: {}", e))
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(state
|
|
||||||
.router
|
.router
|
||||||
.route_completion(&state.client, &req, json_body)
|
.route_completion(&state.client, Some(&headers), &body)
|
||||||
.await)
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/add_worker")]
|
// Worker management endpoints
|
||||||
async fn add_worker(
|
async fn add_worker(
|
||||||
req: HttpRequest,
|
State(state): State<Arc<AppState>>,
|
||||||
query: web::Query<HashMap<String, String>>,
|
Query(params): Query<HashMap<String, String>>,
|
||||||
data: web::Data<AppState>,
|
) -> Response {
|
||||||
) -> impl Responder {
|
let worker_url = match params.get("url") {
|
||||||
let request_id = get_request_id(&req);
|
|
||||||
|
|
||||||
let worker_url = match query.get("url") {
|
|
||||||
Some(url) => url.to_string(),
|
Some(url) => url.to_string(),
|
||||||
None => {
|
None => {
|
||||||
warn!(
|
return (
|
||||||
request_id = %request_id,
|
StatusCode::BAD_REQUEST,
|
||||||
"Add worker request missing URL parameter"
|
"Worker URL required. Provide 'url' query parameter",
|
||||||
);
|
)
|
||||||
return HttpResponse::BadRequest()
|
.into_response();
|
||||||
.body("Worker URL required. Provide 'url' query parameter");
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
info!(
|
match state.router.add_worker(&worker_url).await {
|
||||||
request_id = %request_id,
|
Ok(message) => (StatusCode::OK, message).into_response(),
|
||||||
worker_url = %worker_url,
|
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
|
||||||
"Adding worker"
|
|
||||||
);
|
|
||||||
|
|
||||||
match data.router.add_worker(&worker_url).await {
|
|
||||||
Ok(message) => {
|
|
||||||
info!(
|
|
||||||
request_id = %request_id,
|
|
||||||
worker_url = %worker_url,
|
|
||||||
"Successfully added worker"
|
|
||||||
);
|
|
||||||
HttpResponse::Ok().body(message)
|
|
||||||
}
|
|
||||||
Err(error) => {
|
|
||||||
error!(
|
|
||||||
request_id = %request_id,
|
|
||||||
worker_url = %worker_url,
|
|
||||||
error = %error,
|
|
||||||
"Failed to add worker"
|
|
||||||
);
|
|
||||||
HttpResponse::BadRequest().body(error)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/list_workers")]
|
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
|
||||||
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
|
let worker_list = state.router.get_worker_urls();
|
||||||
let worker_list = data.router.get_worker_urls();
|
Json(serde_json::json!({ "urls": worker_list })).into_response()
|
||||||
HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/remove_worker")]
|
|
||||||
async fn remove_worker(
|
async fn remove_worker(
|
||||||
req: HttpRequest,
|
State(state): State<Arc<AppState>>,
|
||||||
query: web::Query<HashMap<String, String>>,
|
Query(params): Query<HashMap<String, String>>,
|
||||||
data: web::Data<AppState>,
|
) -> Response {
|
||||||
) -> impl Responder {
|
let worker_url = match params.get("url") {
|
||||||
let request_id = get_request_id(&req);
|
|
||||||
|
|
||||||
let worker_url = match query.get("url") {
|
|
||||||
Some(url) => url.to_string(),
|
Some(url) => url.to_string(),
|
||||||
None => {
|
None => return StatusCode::BAD_REQUEST.into_response(),
|
||||||
warn!(
|
|
||||||
request_id = %request_id,
|
|
||||||
"Remove worker request missing URL parameter"
|
|
||||||
);
|
|
||||||
return HttpResponse::BadRequest().finish();
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
info!(
|
state.router.remove_worker(&worker_url);
|
||||||
request_id = %request_id,
|
(
|
||||||
worker_url = %worker_url,
|
StatusCode::OK,
|
||||||
"Removing worker"
|
format!("Successfully removed worker: {}", worker_url),
|
||||||
);
|
)
|
||||||
|
.into_response()
|
||||||
data.router.remove_worker(&worker_url);
|
|
||||||
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/flush_cache")]
|
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||||
async fn flush_cache(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
state.router.flush_cache(&state.client).await
|
||||||
data.router.flush_cache(&data.client).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/get_loads")]
|
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||||
async fn get_loads(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
state.router.get_worker_loads(&state.client).await
|
||||||
data.router.get_worker_loads(&data.client).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ServerConfig {
|
pub struct ServerConfig {
|
||||||
@@ -295,7 +179,58 @@ pub struct ServerConfig {
|
|||||||
pub request_id_headers: Option<Vec<String>>,
|
pub request_id_headers: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
/// Build the Axum application with all routes and middleware
|
||||||
|
pub fn build_app(
|
||||||
|
app_state: Arc<AppState>,
|
||||||
|
max_payload_size: usize,
|
||||||
|
request_id_headers: Vec<String>,
|
||||||
|
cors_allowed_origins: Vec<String>,
|
||||||
|
) -> Router {
|
||||||
|
// Create routes
|
||||||
|
let protected_routes = Router::new()
|
||||||
|
.route("/generate", post(generate))
|
||||||
|
.route("/v1/chat/completions", post(v1_chat_completions))
|
||||||
|
.route("/v1/completions", post(v1_completions));
|
||||||
|
|
||||||
|
let public_routes = Router::new()
|
||||||
|
.route("/liveness", get(liveness))
|
||||||
|
.route("/readiness", get(readiness))
|
||||||
|
.route("/health", get(health))
|
||||||
|
.route("/health_generate", get(health_generate))
|
||||||
|
.route("/v1/models", get(v1_models))
|
||||||
|
.route("/get_model_info", get(get_model_info))
|
||||||
|
.route("/get_server_info", get(get_server_info));
|
||||||
|
|
||||||
|
let admin_routes = Router::new()
|
||||||
|
.route("/add_worker", post(add_worker))
|
||||||
|
.route("/remove_worker", post(remove_worker))
|
||||||
|
.route("/list_workers", get(list_workers))
|
||||||
|
.route("/flush_cache", post(flush_cache))
|
||||||
|
.route("/get_loads", get(get_loads));
|
||||||
|
|
||||||
|
// Build app with all routes and middleware
|
||||||
|
Router::new()
|
||||||
|
.merge(protected_routes)
|
||||||
|
.merge(public_routes)
|
||||||
|
.merge(admin_routes)
|
||||||
|
// Request body size limiting
|
||||||
|
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
||||||
|
max_payload_size,
|
||||||
|
))
|
||||||
|
// Request ID layer - must be added AFTER logging layer in the code
|
||||||
|
// so it executes BEFORE logging layer at runtime (layers execute bottom-up)
|
||||||
|
.layer(crate::middleware::RequestIdLayer::new(request_id_headers))
|
||||||
|
// Custom logging layer that can now see request IDs from extensions
|
||||||
|
.layer(crate::middleware::create_logging_layer())
|
||||||
|
// CORS (should be outermost)
|
||||||
|
.layer(create_cors_layer(cors_allowed_origins))
|
||||||
|
// Fallback
|
||||||
|
.fallback(sink_handler)
|
||||||
|
// State - apply last to get Router<Arc<AppState>>
|
||||||
|
.with_state(app_state)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// Only initialize logging if not already done (for Python bindings support)
|
// Only initialize logging if not already done (for Python bindings support)
|
||||||
static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
|
static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
|
||||||
|
|
||||||
@@ -338,14 +273,20 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
|||||||
|
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
||||||
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
|
.pool_max_idle_per_host(100) // Increase from default of 1 to allow more concurrent connections
|
||||||
|
.timeout(Duration::from_secs(config.request_timeout_secs))
|
||||||
|
.connect_timeout(Duration::from_secs(10)) // Separate connection timeout
|
||||||
|
.tcp_nodelay(true)
|
||||||
|
.tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
|
||||||
.build()
|
.build()
|
||||||
.expect("Failed to create HTTP client");
|
.expect("Failed to create HTTP client");
|
||||||
|
|
||||||
let app_state_init = AppState::new(config.router_config.clone(), client.clone())
|
let app_state = Arc::new(AppState::new(
|
||||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
config.router_config.clone(),
|
||||||
let router_arc = Arc::clone(&app_state_init.router);
|
client.clone(),
|
||||||
let app_state = web::Data::new(app_state_init);
|
config.router_config.max_concurrent_requests,
|
||||||
|
)?);
|
||||||
|
let router_arc = Arc::clone(&app_state.router);
|
||||||
|
|
||||||
// Start the service discovery if enabled
|
// Start the service discovery if enabled
|
||||||
if let Some(service_discovery_config) = config.service_discovery_config {
|
if let Some(service_discovery_config) = config.service_discovery_config {
|
||||||
@@ -383,36 +324,83 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
|||||||
]
|
]
|
||||||
});
|
});
|
||||||
|
|
||||||
HttpServer::new(move || {
|
// Build the application
|
||||||
let request_id_middleware = RequestIdMiddleware::new(request_id_headers.clone());
|
let app = build_app(
|
||||||
|
app_state,
|
||||||
|
config.max_payload_size,
|
||||||
|
request_id_headers,
|
||||||
|
config.router_config.cors_allowed_origins.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
App::new()
|
// Create TCP listener - use the configured host
|
||||||
.wrap(request_id_middleware)
|
let addr = format!("{}:{}", config.host, config.port);
|
||||||
.app_data(app_state.clone())
|
let listener = TcpListener::bind(&addr).await?;
|
||||||
.app_data(
|
|
||||||
web::JsonConfig::default()
|
// Start server with graceful shutdown
|
||||||
.limit(config.max_payload_size)
|
info!("Starting server on {}", addr);
|
||||||
.error_handler(json_error_handler),
|
|
||||||
)
|
// Serve the application with graceful shutdown
|
||||||
.app_data(web::PayloadConfig::default().limit(config.max_payload_size))
|
axum::serve(listener, app)
|
||||||
.service(generate)
|
.with_graceful_shutdown(shutdown_signal())
|
||||||
.service(v1_chat_completions)
|
.await
|
||||||
.service(v1_completions)
|
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
|
||||||
.service(v1_models)
|
|
||||||
.service(get_model_info)
|
Ok(())
|
||||||
.service(liveness)
|
}
|
||||||
.service(readiness)
|
|
||||||
.service(health)
|
// Graceful shutdown handler
|
||||||
.service(health_generate)
|
async fn shutdown_signal() {
|
||||||
.service(get_server_info)
|
let ctrl_c = async {
|
||||||
.service(add_worker)
|
signal::ctrl_c()
|
||||||
.service(remove_worker)
|
.await
|
||||||
.service(list_workers)
|
.expect("failed to install Ctrl+C handler");
|
||||||
.service(flush_cache)
|
};
|
||||||
.service(get_loads)
|
|
||||||
.default_service(web::route().to(sink_handler))
|
#[cfg(unix)]
|
||||||
})
|
let terminate = async {
|
||||||
.bind_auto_h2c((config.host, config.port))?
|
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||||
.run()
|
.expect("failed to install signal handler")
|
||||||
.await
|
.recv()
|
||||||
|
.await;
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
let terminate = std::future::pending::<()>();
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = ctrl_c => {
|
||||||
|
info!("Received Ctrl+C, starting graceful shutdown");
|
||||||
|
},
|
||||||
|
_ = terminate => {
|
||||||
|
info!("Received terminate signal, starting graceful shutdown");
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORS Layer Creation
|
||||||
|
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
|
||||||
|
use tower_http::cors::Any;
|
||||||
|
|
||||||
|
let cors = if allowed_origins.is_empty() {
|
||||||
|
// Allow all origins if none specified
|
||||||
|
tower_http::cors::CorsLayer::new()
|
||||||
|
.allow_origin(Any)
|
||||||
|
.allow_methods(Any)
|
||||||
|
.allow_headers(Any)
|
||||||
|
.expose_headers(Any)
|
||||||
|
} else {
|
||||||
|
// Restrict to specific origins
|
||||||
|
let origins: Vec<http::HeaderValue> = allowed_origins
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|origin| origin.parse().ok())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
tower_http::cors::CorsLayer::new()
|
||||||
|
.allow_origin(origins)
|
||||||
|
.allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
|
||||||
|
.allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
|
||||||
|
.expose_headers([http::header::HeaderName::from_static("x-request-id")])
|
||||||
|
};
|
||||||
|
|
||||||
|
cors.max_age(Duration::from_secs(3600))
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,62 +1,2 @@
|
|||||||
pub mod mock_worker;
|
pub mod mock_worker;
|
||||||
|
pub mod test_app;
|
||||||
use actix_web::web;
|
|
||||||
use reqwest::Client;
|
|
||||||
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
|
|
||||||
use sglang_router_rs::server::AppState;
|
|
||||||
|
|
||||||
/// Helper function to create test router configuration
|
|
||||||
pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
|
|
||||||
RouterConfig {
|
|
||||||
mode: RoutingMode::Regular { worker_urls },
|
|
||||||
policy: PolicyConfig::Random,
|
|
||||||
host: "127.0.0.1".to_string(),
|
|
||||||
port: 3001,
|
|
||||||
max_payload_size: 256 * 1024 * 1024, // 256MB
|
|
||||||
request_timeout_secs: 600,
|
|
||||||
worker_startup_timeout_secs: 300,
|
|
||||||
worker_startup_check_interval_secs: 10,
|
|
||||||
dp_aware: false,
|
|
||||||
api_key: None,
|
|
||||||
discovery: None,
|
|
||||||
metrics: None,
|
|
||||||
log_dir: None,
|
|
||||||
log_level: None,
|
|
||||||
request_id_headers: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper function to create test router configuration with no health check
|
|
||||||
pub fn create_test_config_no_workers() -> RouterConfig {
|
|
||||||
RouterConfig {
|
|
||||||
mode: RoutingMode::Regular {
|
|
||||||
worker_urls: vec![],
|
|
||||||
}, // Empty to skip health check
|
|
||||||
policy: PolicyConfig::Random,
|
|
||||||
host: "127.0.0.1".to_string(),
|
|
||||||
port: 3001,
|
|
||||||
max_payload_size: 256 * 1024 * 1024, // 256MB
|
|
||||||
request_timeout_secs: 600,
|
|
||||||
worker_startup_timeout_secs: 0, // No wait
|
|
||||||
worker_startup_check_interval_secs: 10,
|
|
||||||
dp_aware: false,
|
|
||||||
api_key: None,
|
|
||||||
discovery: None,
|
|
||||||
metrics: None,
|
|
||||||
log_dir: None,
|
|
||||||
log_level: None,
|
|
||||||
request_id_headers: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper function to create test app state
|
|
||||||
pub async fn create_test_app_state(config: RouterConfig) -> Result<web::Data<AppState>, String> {
|
|
||||||
// Create a non-blocking client
|
|
||||||
let client = Client::builder()
|
|
||||||
.timeout(std::time::Duration::from_secs(config.request_timeout_secs))
|
|
||||||
.build()
|
|
||||||
.map_err(|e| e.to_string())?;
|
|
||||||
|
|
||||||
let app_state = AppState::new(config, client)?;
|
|
||||||
Ok(web::Data::new(app_state))
|
|
||||||
}
|
|
||||||
|
|||||||
42
sgl-router/tests/common/test_app.rs
Normal file
42
sgl-router/tests/common/test_app.rs
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
use axum::Router;
|
||||||
|
use reqwest::Client;
|
||||||
|
use sglang_router_rs::{
|
||||||
|
config::RouterConfig,
|
||||||
|
routers::RouterTrait,
|
||||||
|
server::{build_app, AppState},
|
||||||
|
};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Create a test Axum application using the actual server's build_app function
|
||||||
|
pub fn create_test_app(
|
||||||
|
router: Arc<dyn RouterTrait>,
|
||||||
|
client: Client,
|
||||||
|
router_config: &RouterConfig,
|
||||||
|
) -> Router {
|
||||||
|
// Create AppState with the test router
|
||||||
|
let app_state = Arc::new(AppState {
|
||||||
|
router,
|
||||||
|
client,
|
||||||
|
_concurrency_limiter: Arc::new(tokio::sync::Semaphore::new(
|
||||||
|
router_config.max_concurrent_requests,
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Configure request ID headers (use defaults if not specified)
|
||||||
|
let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
|
||||||
|
vec![
|
||||||
|
"x-request-id".to_string(),
|
||||||
|
"x-correlation-id".to_string(),
|
||||||
|
"x-trace-id".to_string(),
|
||||||
|
"request-id".to_string(),
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
|
// Use the actual server's build_app function
|
||||||
|
build_app(
|
||||||
|
app_state,
|
||||||
|
router_config.max_payload_size,
|
||||||
|
request_id_headers,
|
||||||
|
router_config.cors_allowed_origins.clone(),
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,43 +1,27 @@
|
|||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
use actix_web::{http::StatusCode, rt::System, test as actix_test, web, App};
|
|
||||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
|
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
|
||||||
use sglang_router_rs::server::{
|
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||||
add_worker, generate, v1_chat_completions, v1_completions, AppState,
|
use std::sync::Arc;
|
||||||
};
|
|
||||||
|
|
||||||
/// Test context for request type testing
|
/// Test context that manages mock workers
|
||||||
struct RequestTestContext {
|
struct TestContext {
|
||||||
workers: Vec<MockWorker>,
|
workers: Vec<MockWorker>,
|
||||||
app_state: web::Data<AppState>,
|
router: Arc<dyn RouterTrait>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RequestTestContext {
|
impl TestContext {
|
||||||
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
|
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
|
||||||
let mut workers = Vec::new();
|
let mut config = RouterConfig {
|
||||||
let mut worker_urls = Vec::new();
|
|
||||||
|
|
||||||
// Start mock workers
|
|
||||||
for config in worker_configs {
|
|
||||||
let mut worker = MockWorker::new(config);
|
|
||||||
let url = worker.start().await.unwrap();
|
|
||||||
worker_urls.push(url);
|
|
||||||
workers.push(worker);
|
|
||||||
}
|
|
||||||
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
|
||||||
|
|
||||||
// Create router config
|
|
||||||
let config = RouterConfig {
|
|
||||||
mode: RoutingMode::Regular {
|
mode: RoutingMode::Regular {
|
||||||
worker_urls: vec![],
|
worker_urls: vec![],
|
||||||
},
|
},
|
||||||
policy: PolicyConfig::Random,
|
policy: PolicyConfig::Random,
|
||||||
host: "127.0.0.1".to_string(),
|
host: "127.0.0.1".to_string(),
|
||||||
port: 3006,
|
port: 3003,
|
||||||
max_payload_size: 256 * 1024 * 1024,
|
max_payload_size: 256 * 1024 * 1024,
|
||||||
request_timeout_secs: 600,
|
request_timeout_secs: 600,
|
||||||
worker_startup_timeout_secs: 1,
|
worker_startup_timeout_secs: 1,
|
||||||
@@ -49,528 +33,348 @@ impl RequestTestContext {
|
|||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
request_id_headers: None,
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 64,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = Client::builder()
|
let mut workers = Vec::new();
|
||||||
.timeout(std::time::Duration::from_secs(config.request_timeout_secs))
|
let mut worker_urls = Vec::new();
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let app_state = AppState::new(config, client).unwrap();
|
for worker_config in worker_configs {
|
||||||
let app_state = web::Data::new(app_state);
|
let mut worker = MockWorker::new(worker_config);
|
||||||
|
let url = worker.start().await.unwrap();
|
||||||
// Add workers via HTTP API
|
worker_urls.push(url);
|
||||||
let app =
|
workers.push(worker);
|
||||||
actix_test::init_service(App::new().app_data(app_state.clone()).service(add_worker))
|
|
||||||
.await;
|
|
||||||
|
|
||||||
for url in &worker_urls {
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri(&format!("/add_worker?url={}", url))
|
|
||||||
.to_request();
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert!(resp.status().is_success());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
if !workers.is_empty() {
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||||
|
}
|
||||||
|
|
||||||
Self { workers, app_state }
|
config.mode = RoutingMode::Regular { worker_urls };
|
||||||
}
|
|
||||||
|
|
||||||
async fn create_app(
|
let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config))
|
||||||
&self,
|
.await
|
||||||
) -> impl actix_web::dev::Service<
|
.unwrap()
|
||||||
actix_http::Request,
|
.unwrap();
|
||||||
Response = actix_web::dev::ServiceResponse,
|
let router = Arc::from(router);
|
||||||
Error = actix_web::Error,
|
|
||||||
> {
|
if !workers.is_empty() {
|
||||||
actix_test::init_service(
|
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||||
App::new()
|
}
|
||||||
.app_data(self.app_state.clone())
|
|
||||||
.service(generate)
|
Self { workers, router }
|
||||||
.service(v1_chat_completions)
|
|
||||||
.service(v1_completions),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shutdown(mut self) {
|
async fn shutdown(mut self) {
|
||||||
|
// Small delay to ensure any pending operations complete
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
|
||||||
for worker in &mut self.workers {
|
for worker in &mut self.workers {
|
||||||
worker.stop().await;
|
worker.stop().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Another small delay to ensure cleanup completes
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn make_request(
|
||||||
|
&self,
|
||||||
|
endpoint: &str,
|
||||||
|
body: serde_json::Value,
|
||||||
|
) -> Result<serde_json::Value, String> {
|
||||||
|
let client = Client::new();
|
||||||
|
|
||||||
|
// Get any worker URL for testing
|
||||||
|
let worker_urls = self.router.get_worker_urls();
|
||||||
|
if worker_urls.is_empty() {
|
||||||
|
return Err("No available workers".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let worker_url = &worker_urls[0];
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.post(&format!("{}{}", worker_url, endpoint))
|
||||||
|
.json(&body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Request failed: {}", e))?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(format!("Request failed with status: {}", response.status()));
|
||||||
|
}
|
||||||
|
|
||||||
|
response
|
||||||
|
.json::<serde_json::Value>()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to parse response: {}", e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod generate_input_format_tests {
|
mod request_format_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[tokio::test]
|
||||||
fn test_generate_with_text_input() {
|
async fn test_generate_request_formats() {
|
||||||
System::new().block_on(async {
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
port: 19001,
|
||||||
port: 21001,
|
worker_type: WorkerType::Regular,
|
||||||
worker_type: WorkerType::Regular,
|
health_status: HealthStatus::Healthy,
|
||||||
health_status: HealthStatus::Healthy,
|
response_delay_ms: 0,
|
||||||
response_delay_ms: 0,
|
fail_rate: 0.0,
|
||||||
fail_rate: 0.0,
|
}])
|
||||||
}])
|
.await;
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
// Test 1: Basic text request
|
||||||
|
let payload = json!({
|
||||||
// Standard text input
|
"text": "Hello, world!",
|
||||||
let payload = json!({
|
"stream": false
|
||||||
"text": "Hello world",
|
|
||||||
"stream": false
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
let body: serde_json::Value = actix_test::read_body_json(resp).await;
|
|
||||||
assert!(body.get("text").is_some());
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
});
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
fn test_generate_with_prompt_input() {
|
assert!(result.is_ok());
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21002,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
// Test 2: Request with sampling parameters
|
||||||
|
let payload = json!({
|
||||||
// Prompt input (alternative to text)
|
"text": "Tell me a story",
|
||||||
let payload = json!({
|
"sampling_params": {
|
||||||
"prompt": "Once upon a time",
|
|
||||||
"stream": false
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_generate_with_input_ids() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21003,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
// Input IDs (tokenized input)
|
|
||||||
let payload = json!({
|
|
||||||
"input_ids": [1, 2, 3, 4, 5],
|
|
||||||
"stream": false
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_generate_with_all_parameters() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21004,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
// All generation parameters
|
|
||||||
let payload = json!({
|
|
||||||
"text": "Complete this",
|
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"top_p": 0.9,
|
|
||||||
"top_k": 50,
|
|
||||||
"max_new_tokens": 100,
|
"max_new_tokens": 100,
|
||||||
"min_new_tokens": 10,
|
"top_p": 0.9
|
||||||
"frequency_penalty": 0.5,
|
},
|
||||||
"presence_penalty": 0.3,
|
"stream": false
|
||||||
"repetition_penalty": 1.1,
|
|
||||||
"stop": [".", "!", "?"],
|
|
||||||
"stream": false
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod chat_completion_format_tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_with_system_message() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21010,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"model": "test-model",
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Hello!"}
|
|
||||||
]
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/v1/chat/completions")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Note: Function calling and tools tests are commented out because
|
|
||||||
// they require special handling in the mock worker that's not implemented yet.
|
|
||||||
// In production, these would be forwarded to the actual model.
|
|
||||||
|
|
||||||
// #[test]
|
|
||||||
// fn test_chat_with_function_calling() {
|
|
||||||
// // Test would go here when mock worker supports function calling
|
|
||||||
// }
|
|
||||||
|
|
||||||
// #[test]
|
|
||||||
// fn test_chat_with_tools() {
|
|
||||||
// // Test would go here when mock worker supports tools
|
|
||||||
// }
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_with_response_format() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21013,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"model": "test-model",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Return JSON"}
|
|
||||||
],
|
|
||||||
"response_format": {
|
|
||||||
"type": "json_object"
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/v1/chat/completions")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod completion_format_tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_with_single_prompt() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21020,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"model": "test-model",
|
|
||||||
"prompt": "Once upon a time",
|
|
||||||
"max_tokens": 50
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/v1/completions")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
let body: serde_json::Value = actix_test::read_body_json(resp).await;
|
|
||||||
assert!(body.get("choices").is_some());
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_with_batch_prompts() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21021,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"model": "test-model",
|
|
||||||
"prompt": ["First prompt", "Second prompt", "Third prompt"],
|
|
||||||
"max_tokens": 30
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/v1/completions")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_with_echo() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21022,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"model": "test-model",
|
|
||||||
"prompt": "Echo this prompt",
|
|
||||||
"echo": true,
|
|
||||||
"max_tokens": 20
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/v1/completions")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_with_logprobs() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21023,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"model": "test-model",
|
|
||||||
"prompt": "Calculate probability",
|
|
||||||
"logprobs": 5,
|
|
||||||
"max_tokens": 10
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/v1/completions")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_with_suffix() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21024,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"model": "test-model",
|
|
||||||
"prompt": "Insert text here: ",
|
|
||||||
"suffix": " and continue from here.",
|
|
||||||
"max_tokens": 20
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/v1/completions")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod stop_sequence_tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_stop_sequences_array() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21030,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"text": "Generate until stop",
|
|
||||||
"stop": [".", "!", "?", "\n"],
|
|
||||||
"stream": false
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_stop_sequences_string() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 21031,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"text": "Generate until stop",
|
|
||||||
"stop": "\n\n",
|
|
||||||
"stream": false
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Test 3: Request with input_ids
|
||||||
|
let payload = json!({
|
||||||
|
"input_ids": [1, 2, 3, 4, 5],
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_new_tokens": 50
|
||||||
|
},
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_v1_chat_completions_formats() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 19002,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Test 1: Basic chat completion
|
||||||
|
let payload = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "Hello!"}
|
||||||
|
],
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/v1/chat/completions", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let response = result.unwrap();
|
||||||
|
assert!(response.get("choices").is_some());
|
||||||
|
assert!(response.get("id").is_some());
|
||||||
|
assert_eq!(
|
||||||
|
response.get("object").and_then(|v| v.as_str()),
|
||||||
|
Some("chat.completion")
|
||||||
|
);
|
||||||
|
|
||||||
|
// Test 2: Chat completion with parameters
|
||||||
|
let payload = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Tell me a joke"}
|
||||||
|
],
|
||||||
|
"temperature": 0.8,
|
||||||
|
"max_tokens": 150,
|
||||||
|
"top_p": 0.95,
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/v1/chat/completions", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_v1_completions_formats() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 19003,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Test 1: Basic completion
|
||||||
|
let payload = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Once upon a time",
|
||||||
|
"max_tokens": 50,
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/v1/completions", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let response = result.unwrap();
|
||||||
|
assert!(response.get("choices").is_some());
|
||||||
|
assert_eq!(
|
||||||
|
response.get("object").and_then(|v| v.as_str()),
|
||||||
|
Some("text_completion")
|
||||||
|
);
|
||||||
|
|
||||||
|
// Test 2: Completion with array prompt
|
||||||
|
let payload = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": ["First prompt", "Second prompt"],
|
||||||
|
"temperature": 0.5,
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/v1/completions", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Test 3: Completion with logprobs
|
||||||
|
let payload = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "The capital of France is",
|
||||||
|
"max_tokens": 10,
|
||||||
|
"logprobs": 5,
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/v1/completions", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_batch_requests() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 19004,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Test batch text generation
|
||||||
|
let payload = json!({
|
||||||
|
"text": ["First text", "Second text", "Third text"],
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_new_tokens": 50
|
||||||
|
},
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Test batch with input_ids
|
||||||
|
let payload = json!({
|
||||||
|
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_special_parameters() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 19005,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Test with return_logprob
|
||||||
|
let payload = json!({
|
||||||
|
"text": "Test",
|
||||||
|
"return_logprob": true,
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Test with json_schema
|
||||||
|
let payload = json!({
|
||||||
|
"text": "Generate JSON",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0.0,
|
||||||
|
"json_schema": "$$ANY$$"
|
||||||
|
},
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
// Test with ignore_eos
|
||||||
|
let payload = json!({
|
||||||
|
"text": "Continue forever",
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_new_tokens": 100,
|
||||||
|
"ignore_eos": true
|
||||||
|
},
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_error_handling() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 19006,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Test with empty body - should still work with mock worker
|
||||||
|
let payload = json!({});
|
||||||
|
|
||||||
|
let result = ctx.make_request("/generate", payload).await;
|
||||||
|
// Mock worker accepts empty body
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,47 +1,28 @@
|
|||||||
mod common;
|
mod common;
|
||||||
|
|
||||||
use actix_web::{http::StatusCode, rt::System, test as actix_test, web, App};
|
|
||||||
use bytes::Bytes;
|
|
||||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||||
|
use futures_util::StreamExt;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
|
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
|
||||||
use sglang_router_rs::server::{
|
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||||
add_worker, generate, list_workers, v1_chat_completions, v1_completions, AppState,
|
use std::sync::Arc;
|
||||||
};
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
/// Test context for streaming tests
|
/// Test context that manages mock workers
|
||||||
struct StreamingTestContext {
|
struct TestContext {
|
||||||
workers: Vec<MockWorker>,
|
workers: Vec<MockWorker>,
|
||||||
app_state: web::Data<AppState>,
|
router: Arc<dyn RouterTrait>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StreamingTestContext {
|
impl TestContext {
|
||||||
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
|
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
|
||||||
let mut workers = Vec::new();
|
let mut config = RouterConfig {
|
||||||
let mut worker_urls = Vec::new();
|
|
||||||
|
|
||||||
// Start mock workers
|
|
||||||
for config in worker_configs {
|
|
||||||
let mut worker = MockWorker::new(config);
|
|
||||||
let url = worker.start().await.unwrap();
|
|
||||||
worker_urls.push(url);
|
|
||||||
workers.push(worker);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Give workers time to start
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
|
||||||
|
|
||||||
// Create router config with empty worker URLs initially
|
|
||||||
// We'll add workers via the /add_worker endpoint
|
|
||||||
let config = RouterConfig {
|
|
||||||
mode: RoutingMode::Regular {
|
mode: RoutingMode::Regular {
|
||||||
worker_urls: vec![],
|
worker_urls: vec![],
|
||||||
},
|
},
|
||||||
policy: PolicyConfig::Random,
|
policy: PolicyConfig::Random,
|
||||||
host: "127.0.0.1".to_string(),
|
host: "127.0.0.1".to_string(),
|
||||||
port: 3003,
|
port: 3004,
|
||||||
max_payload_size: 256 * 1024 * 1024,
|
max_payload_size: 256 * 1024 * 1024,
|
||||||
request_timeout_secs: 600,
|
request_timeout_secs: 600,
|
||||||
worker_startup_timeout_secs: 1,
|
worker_startup_timeout_secs: 1,
|
||||||
@@ -53,530 +34,325 @@ impl StreamingTestContext {
|
|||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
request_id_headers: None,
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 64,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
let client = Client::builder()
|
let mut workers = Vec::new();
|
||||||
.timeout(std::time::Duration::from_secs(config.request_timeout_secs))
|
let mut worker_urls = Vec::new();
|
||||||
.build()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let app_state = AppState::new(config, client).unwrap();
|
for worker_config in worker_configs {
|
||||||
let app_state = web::Data::new(app_state);
|
let mut worker = MockWorker::new(worker_config);
|
||||||
|
let url = worker.start().await.unwrap();
|
||||||
// Add workers via HTTP API
|
worker_urls.push(url);
|
||||||
let app =
|
workers.push(worker);
|
||||||
actix_test::init_service(App::new().app_data(app_state.clone()).service(add_worker))
|
|
||||||
.await;
|
|
||||||
|
|
||||||
for url in &worker_urls {
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri(&format!("/add_worker?url={}", url))
|
|
||||||
.to_request();
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert!(resp.status().is_success());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
if !workers.is_empty() {
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||||
|
}
|
||||||
|
|
||||||
Self { workers, app_state }
|
config.mode = RoutingMode::Regular { worker_urls };
|
||||||
}
|
|
||||||
|
|
||||||
async fn create_app(
|
let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config))
|
||||||
&self,
|
.await
|
||||||
) -> impl actix_web::dev::Service<
|
.unwrap()
|
||||||
actix_http::Request,
|
.unwrap();
|
||||||
Response = actix_web::dev::ServiceResponse,
|
let router = Arc::from(router);
|
||||||
Error = actix_web::Error,
|
|
||||||
> {
|
if !workers.is_empty() {
|
||||||
actix_test::init_service(
|
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||||
App::new()
|
}
|
||||||
.app_data(self.app_state.clone())
|
|
||||||
.service(generate)
|
Self { workers, router }
|
||||||
.service(v1_chat_completions)
|
|
||||||
.service(v1_completions)
|
|
||||||
.service(list_workers),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shutdown(mut self) {
|
async fn shutdown(mut self) {
|
||||||
|
// Small delay to ensure any pending operations complete
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
|
||||||
for worker in &mut self.workers {
|
for worker in &mut self.workers {
|
||||||
worker.stop().await;
|
worker.stop().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Another small delay to ensure cleanup completes
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Parse SSE (Server-Sent Events) from response body
|
async fn make_streaming_request(
|
||||||
async fn parse_sse_stream(body: Bytes) -> Vec<serde_json::Value> {
|
&self,
|
||||||
let text = String::from_utf8_lossy(&body);
|
endpoint: &str,
|
||||||
let mut events = Vec::new();
|
body: serde_json::Value,
|
||||||
|
) -> Result<Vec<String>, String> {
|
||||||
|
let client = Client::new();
|
||||||
|
|
||||||
for line in text.lines() {
|
// Get any worker URL for testing
|
||||||
if line.starts_with("data: ") {
|
let worker_urls = self.router.get_worker_urls();
|
||||||
let data = &line[6..];
|
if worker_urls.is_empty() {
|
||||||
if data == "[DONE]" {
|
return Err("No available workers".to_string());
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
|
|
||||||
events.push(json);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
events
|
let worker_url = &worker_urls[0];
|
||||||
}
|
|
||||||
|
let response = client
|
||||||
#[cfg(test)]
|
.post(&format!("{}{}", worker_url, endpoint))
|
||||||
mod basic_streaming_tests {
|
.json(&body)
|
||||||
use super::*;
|
.send()
|
||||||
|
.await
|
||||||
#[test]
|
.map_err(|e| format!("Request failed: {}", e))?;
|
||||||
fn test_router_uses_mock_workers() {
|
|
||||||
System::new().block_on(async {
|
if !response.status().is_success() {
|
||||||
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
|
return Err(format!("Request failed with status: {}", response.status()));
|
||||||
port: 19000,
|
}
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
// Check if it's a streaming response
|
||||||
response_delay_ms: 0,
|
let content_type = response
|
||||||
fail_rate: 0.0,
|
.headers()
|
||||||
}])
|
.get("content-type")
|
||||||
.await;
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("");
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
if !content_type.contains("text/event-stream") {
|
||||||
// Verify workers are registered with the router
|
return Err("Response is not a stream".to_string());
|
||||||
let req = actix_test::TestRequest::get()
|
}
|
||||||
.uri("/list_workers")
|
|
||||||
.to_request();
|
let mut stream = response.bytes_stream();
|
||||||
|
let mut events = Vec::new();
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
while let Some(chunk) = stream.next().await {
|
||||||
|
if let Ok(bytes) = chunk {
|
||||||
let body: serde_json::Value = actix_test::read_body_json(resp).await;
|
let text = String::from_utf8_lossy(&bytes);
|
||||||
let urls = body["urls"].as_array().unwrap();
|
for line in text.lines() {
|
||||||
assert_eq!(urls.len(), 1);
|
if line.starts_with("data: ") {
|
||||||
assert!(urls[0].as_str().unwrap().contains("19000"));
|
events.push(line[6..].to_string());
|
||||||
|
}
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_generate_streaming() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 19001,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"text": "Hello, streaming world!",
|
|
||||||
"stream": true,
|
|
||||||
"max_new_tokens": 50
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
// Check content type
|
|
||||||
let content_type = resp.headers().get("content-type").unwrap();
|
|
||||||
assert_eq!(content_type, "text/event-stream");
|
|
||||||
|
|
||||||
// Read streaming body
|
|
||||||
let body = actix_test::read_body(resp).await;
|
|
||||||
let events = parse_sse_stream(body).await;
|
|
||||||
|
|
||||||
// Verify we got multiple chunks
|
|
||||||
assert!(events.len() > 1);
|
|
||||||
|
|
||||||
// Verify first chunk has text
|
|
||||||
assert!(events[0].get("text").is_some());
|
|
||||||
|
|
||||||
// Verify last chunk has finish_reason in meta_info
|
|
||||||
let last_event = events.last().unwrap();
|
|
||||||
assert!(last_event.get("meta_info").is_some());
|
|
||||||
let meta_info = &last_event["meta_info"];
|
|
||||||
assert!(meta_info.get("finish_reason").is_some());
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_chat_completion_streaming() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 19002,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"model": "test-model",
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hello, streaming!"}
|
|
||||||
],
|
|
||||||
"stream": true
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/v1/chat/completions")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
assert_eq!(
|
|
||||||
resp.headers().get("content-type").unwrap(),
|
|
||||||
"text/event-stream"
|
|
||||||
);
|
|
||||||
|
|
||||||
let body = actix_test::read_body(resp).await;
|
|
||||||
let events = parse_sse_stream(body).await;
|
|
||||||
|
|
||||||
// Verify we got streaming events
|
|
||||||
// Note: Mock doesn't provide full OpenAI format, just verify we got chunks
|
|
||||||
assert!(!events.is_empty(), "Should have received streaming events");
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_completion_streaming() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 19003,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"model": "test-model",
|
|
||||||
"prompt": "Once upon a time",
|
|
||||||
"stream": true,
|
|
||||||
"max_tokens": 30
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/v1/completions")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
assert_eq!(
|
|
||||||
resp.headers().get("content-type").unwrap(),
|
|
||||||
"text/event-stream"
|
|
||||||
);
|
|
||||||
|
|
||||||
let _body = actix_test::read_body(resp).await;
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod streaming_performance_tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_streaming_first_token_latency() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 19010,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 10, // Small delay to simulate processing
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"text": "Measure latency",
|
|
||||||
"stream": true
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let start = Instant::now();
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
// Note: actix_test framework doesn't provide easy access to streaming chunks.
|
|
||||||
// The ideal solution would be to:
|
|
||||||
// 1. Start the router as a real HTTP server
|
|
||||||
// 2. Use reqwest::Client to make streaming requests
|
|
||||||
// 3. Measure time to first chunk properly
|
|
||||||
//
|
|
||||||
// For now, we verify that streaming responses work correctly,
|
|
||||||
// but cannot accurately measure TTFT with actix_test.
|
|
||||||
let body = actix_test::read_body(resp).await;
|
|
||||||
let total_time = start.elapsed();
|
|
||||||
|
|
||||||
// Verify we got streaming data
|
|
||||||
let events = parse_sse_stream(body).await;
|
|
||||||
assert!(!events.is_empty(), "Should receive streaming events");
|
|
||||||
|
|
||||||
// With mock worker delay of 10ms, total time should still be reasonable
|
|
||||||
assert!(
|
|
||||||
total_time.as_millis() < 1000,
|
|
||||||
"Total response took {}ms",
|
|
||||||
total_time.as_millis()
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_concurrent_streaming_requests() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
// Test basic concurrent streaming functionality
|
|
||||||
let ctx = StreamingTestContext::new(vec![
|
|
||||||
MockWorkerConfig {
|
|
||||||
port: 19050,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
},
|
|
||||||
MockWorkerConfig {
|
|
||||||
port: 19051,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
},
|
|
||||||
])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
// Send a moderate number of concurrent requests for unit testing
|
|
||||||
use futures::future::join_all;
|
|
||||||
let mut futures = Vec::new();
|
|
||||||
|
|
||||||
for i in 0..20 {
|
|
||||||
let app_ref = &app;
|
|
||||||
let future = async move {
|
|
||||||
let payload = json!({
|
|
||||||
"text": format!("Concurrent request {}", i),
|
|
||||||
"stream": true,
|
|
||||||
"max_new_tokens": 5
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(app_ref, req).await;
|
|
||||||
resp.status() == StatusCode::OK
|
|
||||||
};
|
|
||||||
|
|
||||||
futures.push(future);
|
|
||||||
}
|
|
||||||
|
|
||||||
let results = join_all(futures).await;
|
|
||||||
let successful = results.iter().filter(|&&r| r).count();
|
|
||||||
|
|
||||||
// All requests should succeed in a unit test environment
|
|
||||||
assert_eq!(
|
|
||||||
successful, 20,
|
|
||||||
"Expected all 20 requests to succeed, got {}",
|
|
||||||
successful
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Note: Extreme load testing has been moved to benches/streaming_load_test.rs
|
|
||||||
// Run with: cargo run --release --bin streaming_load_test 10000 10
|
|
||||||
// Or: cargo bench streaming_load_test
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod streaming_error_tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_streaming_with_worker_failure() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 19020,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 1.0, // Always fail
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"text": "This should fail",
|
|
||||||
"stream": true
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_streaming_with_invalid_payload() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 19021,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
// Missing required fields
|
|
||||||
"stream": true
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
// TODO: Router should validate payload and reject requests with missing content fields
|
|
||||||
// Currently, the router accepts requests with no prompt/text/input_ids which is a bug
|
|
||||||
// This should return StatusCode::BAD_REQUEST once proper validation is implemented
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod streaming_content_tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_unicode_streaming() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 19030,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"text": "Test Unicode: 你好世界 🌍 émojis",
|
|
||||||
"stream": true
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
let body = actix_test::read_body(resp).await;
|
|
||||||
let events = parse_sse_stream(body).await;
|
|
||||||
|
|
||||||
// Verify events were parsed correctly (Unicode didn't break parsing)
|
|
||||||
assert!(!events.is_empty());
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_incremental_text_building() {
|
|
||||||
System::new().block_on(async {
|
|
||||||
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 19031,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
let payload = json!({
|
|
||||||
"text": "Build text incrementally",
|
|
||||||
"stream": true
|
|
||||||
});
|
|
||||||
|
|
||||||
let req = actix_test::TestRequest::post()
|
|
||||||
.uri("/generate")
|
|
||||||
.set_json(&payload)
|
|
||||||
.to_request();
|
|
||||||
|
|
||||||
let resp = actix_test::call_service(&app, req).await;
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
let body = actix_test::read_body(resp).await;
|
|
||||||
let events = parse_sse_stream(body).await;
|
|
||||||
|
|
||||||
// Build complete text from chunks
|
|
||||||
let mut complete_text = String::new();
|
|
||||||
for event in &events {
|
|
||||||
if let Some(text) = event.get("text").and_then(|t| t.as_str()) {
|
|
||||||
complete_text.push_str(text);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Verify we got some text
|
Ok(events)
|
||||||
assert!(!complete_text.is_empty());
|
}
|
||||||
|
}
|
||||||
ctx.shutdown().await;
|
|
||||||
});
|
#[cfg(test)]
|
||||||
|
mod streaming_tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_generate_streaming() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 20001,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 10,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let payload = json!({
|
||||||
|
"text": "Stream test",
|
||||||
|
"stream": true,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_new_tokens": 10
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_streaming_request("/generate", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let events = result.unwrap();
|
||||||
|
// Should have at least one data chunk and [DONE]
|
||||||
|
assert!(events.len() >= 2);
|
||||||
|
assert_eq!(events.last().unwrap(), "[DONE]");
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_v1_chat_completions_streaming() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 20002,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 10,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let payload = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Count to 3"}
|
||||||
|
],
|
||||||
|
"stream": true,
|
||||||
|
"max_tokens": 20
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx
|
||||||
|
.make_streaming_request("/v1/chat/completions", payload)
|
||||||
|
.await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let events = result.unwrap();
|
||||||
|
assert!(events.len() >= 2); // At least one chunk + [DONE]
|
||||||
|
|
||||||
|
// Verify events are valid JSON (except [DONE])
|
||||||
|
for event in &events {
|
||||||
|
if event != "[DONE]" {
|
||||||
|
let parsed: Result<serde_json::Value, _> = serde_json::from_str(event);
|
||||||
|
assert!(parsed.is_ok(), "Invalid JSON in SSE event: {}", event);
|
||||||
|
|
||||||
|
let json = parsed.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
json.get("object").and_then(|v| v.as_str()),
|
||||||
|
Some("chat.completion.chunk")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_v1_completions_streaming() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 20003,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 10,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let payload = json!({
|
||||||
|
"model": "test-model",
|
||||||
|
"prompt": "Once upon a time",
|
||||||
|
"stream": true,
|
||||||
|
"max_tokens": 15
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_streaming_request("/v1/completions", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let events = result.unwrap();
|
||||||
|
assert!(events.len() >= 2); // At least one chunk + [DONE]
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_streaming_with_error() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 20004,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 1.0, // Always fail
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let payload = json!({
|
||||||
|
"text": "This should fail",
|
||||||
|
"stream": true
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_streaming_request("/generate", payload).await;
|
||||||
|
// With fail_rate: 1.0, the request should fail
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_streaming_timeouts() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 20005,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 100, // Slow response
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let payload = json!({
|
||||||
|
"text": "Slow stream",
|
||||||
|
"stream": true,
|
||||||
|
"sampling_params": {
|
||||||
|
"max_new_tokens": 5
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let result = ctx.make_streaming_request("/generate", payload).await;
|
||||||
|
let elapsed = start.elapsed();
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let events = result.unwrap();
|
||||||
|
|
||||||
|
// Should have received multiple chunks over time
|
||||||
|
assert!(!events.is_empty());
|
||||||
|
assert!(elapsed.as_millis() >= 100); // At least one delay
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_batch_streaming() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 20006,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 10,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Batch request with streaming
|
||||||
|
let payload = json!({
|
||||||
|
"text": ["First", "Second", "Third"],
|
||||||
|
"stream": true,
|
||||||
|
"sampling_params": {
|
||||||
|
"max_new_tokens": 5
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = ctx.make_streaming_request("/generate", payload).await;
|
||||||
|
assert!(result.is_ok());
|
||||||
|
|
||||||
|
let events = result.unwrap();
|
||||||
|
// Should have multiple events for batch
|
||||||
|
assert!(events.len() >= 4); // At least 3 responses + [DONE]
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_sse_format_parsing() {
|
||||||
|
// Test SSE format parsing
|
||||||
|
let parse_sse_chunk = |chunk: &[u8]| -> Vec<String> {
|
||||||
|
let text = String::from_utf8_lossy(chunk);
|
||||||
|
text.lines()
|
||||||
|
.filter(|line| line.starts_with("data: "))
|
||||||
|
.map(|line| line[6..].to_string())
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let sse_data =
|
||||||
|
b"data: {\"text\":\"Hello\"}\n\ndata: {\"text\":\" world\"}\n\ndata: [DONE]\n\n";
|
||||||
|
let events = parse_sse_chunk(sse_data);
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 3);
|
||||||
|
assert_eq!(events[0], "{\"text\":\"Hello\"}");
|
||||||
|
assert_eq!(events[1], "{\"text\":\" world\"}");
|
||||||
|
assert_eq!(events[2], "[DONE]");
|
||||||
|
|
||||||
|
// Test with mixed content
|
||||||
|
let mixed = b"event: message\ndata: {\"test\":true}\n\n: comment\ndata: [DONE]\n\n";
|
||||||
|
let events = parse_sse_chunk(mixed);
|
||||||
|
|
||||||
|
assert_eq!(events.len(), 2);
|
||||||
|
assert_eq!(events[0], "{\"test\":true}");
|
||||||
|
assert_eq!(events[1], "[DONE]");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -176,6 +176,8 @@ mod test_pd_routing {
|
|||||||
log_dir: None,
|
log_dir: None,
|
||||||
log_level: None,
|
log_level: None,
|
||||||
request_id_headers: None,
|
request_id_headers: None,
|
||||||
|
max_concurrent_requests: 64,
|
||||||
|
cors_allowed_origins: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
// Router creation will fail due to health checks, but config should be valid
|
// Router creation will fail due to health checks, but config should be valid
|
||||||
|
|||||||
Reference in New Issue
Block a user