[router]: Add Embedding routing logic (#10129)

Signed-off-by: Jintao Zhang <zhangjintao9020@gmail.com>
Co-authored-by: Waël Boukhobza <wawa_wael@live.fr>
This commit is contained in:
Jintao Zhang
2025-09-15 09:44:35 +08:00
committed by GitHub
parent dcee42c200
commit f9ee6ae17a
17 changed files with 452 additions and 69 deletions

View File

@@ -143,6 +143,18 @@ pub fn init_metrics() {
"Generate request duration"
);
// Embedding request specific metrics
describe_counter!("sgl_router_embeddings_total", "Total embedding requests");
describe_histogram!(
"sgl_router_embeddings_duration_seconds",
"Embedding request duration"
);
describe_counter!(
"sgl_router_embeddings_errors_total",
"Embedding request errors"
);
describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size");
// Running requests gauge for cache-aware policy
describe_gauge!(
"sgl_router_running_requests",
@@ -440,6 +452,27 @@ impl RouterMetrics {
histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
}
// Embeddings metrics
/// Adds one to the `sgl_router_embeddings_total` counter.
pub fn record_embeddings_request() {
    let total = counter!("sgl_router_embeddings_total");
    total.increment(1);
}
/// Records `duration`, expressed as fractional seconds, in the
/// `sgl_router_embeddings_duration_seconds` histogram.
pub fn record_embeddings_duration(duration: Duration) {
    let seconds = duration.as_secs_f64();
    histogram!("sgl_router_embeddings_duration_seconds").record(seconds);
}
/// Adds one to `sgl_router_embeddings_errors_total`, tagging the sample with an
/// `error_type` label whose value is the given string.
pub fn record_embeddings_error(error_type: &str) {
    // The label value is handed to the macro as an owned string.
    let label = error_type.to_owned();
    let errors = counter!("sgl_router_embeddings_errors_total", "error_type" => label);
    errors.increment(1);
}
/// Sets the `sgl_router_embeddings_queue_size` gauge to `size`.
pub fn set_embeddings_queue_size(size: usize) {
    // Gauges take floating-point values, so the count is converted first.
    let depth = size as f64;
    gauge!("sgl_router_embeddings_queue_size").set(depth);
}
// Running requests for cache-aware policy
pub fn set_running_requests(worker: &str, count: usize) {
gauge!("sgl_router_running_requests",

View File

@@ -3,6 +3,7 @@ use axum::{
response::IntoResponse, response::Response,
};
use rand::Rng;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
@@ -13,6 +14,7 @@ use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
pub use crate::core::token_bucket::TokenBucket;
use crate::metrics::RouterMetrics;
use crate::server::AppState;
/// Generate OpenAI-compatible request ID based on endpoint
@@ -441,6 +443,11 @@ pub async fn concurrency_limit_middleware(
request: Request<axum::body::Body>,
next: Next,
) -> Response {
// Static counter for embeddings queue size
static EMBEDDINGS_QUEUE_SIZE: AtomicU64 = AtomicU64::new(0);
// Identify if this is an embeddings request based on path
let is_embeddings = request.uri().path().contains("/v1/embeddings");
let token_bucket = app_state.context.rate_limiter.clone();
// Try to acquire token immediately
@@ -468,10 +475,23 @@ pub async fn concurrency_limit_middleware(
// Try to send to queue
match queue_tx.try_send(queued) {
Ok(_) => {
// On successful enqueue, update embeddings queue gauge if applicable
if is_embeddings {
let new_val = EMBEDDINGS_QUEUE_SIZE.fetch_add(1, Ordering::Relaxed) + 1;
RouterMetrics::set_embeddings_queue_size(new_val as usize);
}
// Wait for token from queue processor
match permit_rx.await {
Ok(Ok(())) => {
debug!("Acquired token from queue");
// Dequeue for embeddings
if is_embeddings {
let new_val =
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
RouterMetrics::set_embeddings_queue_size(new_val as usize);
}
let response = next.run(request).await;
// Return the token to the bucket
@@ -481,10 +501,22 @@ pub async fn concurrency_limit_middleware(
}
Ok(Err(status)) => {
warn!("Queue returned error status: {}", status);
// Dequeue for embeddings on error
if is_embeddings {
let new_val =
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
RouterMetrics::set_embeddings_queue_size(new_val as usize);
}
status.into_response()
}
Err(_) => {
error!("Queue response channel closed");
// Dequeue for embeddings on channel error
if is_embeddings {
let new_val =
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
RouterMetrics::set_embeddings_queue_size(new_val as usize);
}
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}

View File

@@ -41,7 +41,10 @@ use std::collections::HashMap;
// 6. **SGLANG SPEC - RERANK API**
// - Request/Response structures
//
// 7. **COMMON**
// 7. **OPENAI SPEC - Embeddings API**
// - Request structures
//
// 8. **COMMON**
// - GenerationRequest trait
// - StringOrArray & LoRAPath types
// - Helper functions
@@ -2013,6 +2016,61 @@ impl RerankResponse {
}
}
// ==================================================================
// = OPENAI SPEC - Embeddings API =
// ==================================================================
/// Request body for the OpenAI-compatible embeddings endpoint.
///
/// Field types are intentionally loose (`input` is an arbitrary JSON value) so the
/// payload can be passed through to workers. Every `Option` field is left out of
/// the serialized JSON when it is `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EmbeddingRequest {
/// ID of the model to use.
pub model: String,
/// Content to embed, kept as raw JSON: a string, an array of strings, tokens,
/// or batch inputs.
pub input: serde_json::Value,
/// Optional encoding format (e.g., "float", "base64"); omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<String>,
/// Optional user identifier; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// Optional number of dimensions for the embedding; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
/// SGLang extension: request id for tracking; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub rid: Option<String>,
}
impl GenerationRequest for EmbeddingRequest {
    /// Always `false`: embeddings are non-streaming.
    fn is_stream(&self) -> bool {
        false
    }

    /// Returns the requested model name; always `Some`, since `model` is not optional.
    fn get_model(&self) -> Option<&str> {
        Some(self.model.as_str())
    }

    /// Best-effort extraction of text for routing decisions.
    ///
    /// * A JSON string is returned as-is.
    /// * For a JSON array, the top-level string elements are joined with a single
    ///   space; every other element (numbers, nested arrays, objects, ...) is skipped.
    /// * Any other JSON shape yields an empty string.
    fn extract_text_for_routing(&self) -> String {
        match &self.input {
            serde_json::Value::String(text) => text.to_owned(),
            serde_json::Value::Array(items) => {
                let mut joined = String::new();
                let pieces = items.iter().filter_map(serde_json::Value::as_str);
                for (position, piece) in pieces.enumerate() {
                    // Separator goes between pieces, never before the first one.
                    if position > 0 {
                        joined.push(' ');
                    }
                    joined.push_str(piece);
                }
                joined
            }
            _ => String::new(),
        }
    }
}
// ==================================================================
// = COMMON =
// ==================================================================
@@ -2715,4 +2773,102 @@ mod tests {
assert_eq!(deserialized.results.len(), 2);
assert_eq!(deserialized.model, response.model);
}
// ==================================================================
// = EMBEDDINGS REQUEST TESTS =
// ==================================================================
#[test]
fn test_embedding_request_serialization_string_input() {
    // Every optional field is populated so the round trip exercises all of them.
    let original = EmbeddingRequest {
        model: String::from("test-emb"),
        input: serde_json::json!("hello"),
        encoding_format: Some(String::from("float")),
        user: Some(String::from("user-1")),
        dimensions: Some(128),
        rid: Some(String::from("rid-123")),
    };

    let json = serde_json::to_string(&original).unwrap();
    let round_tripped: EmbeddingRequest = serde_json::from_str(&json).unwrap();

    assert_eq!(round_tripped.model, original.model);
    assert_eq!(round_tripped.input, original.input);
    assert_eq!(round_tripped.encoding_format, original.encoding_format);
    assert_eq!(round_tripped.user, original.user);
    assert_eq!(round_tripped.dimensions, original.dimensions);
    assert_eq!(round_tripped.rid, original.rid);
}
#[test]
fn test_embedding_request_serialization_array_input() {
    // Array input with all optional fields absent.
    let original = EmbeddingRequest {
        model: String::from("test-emb"),
        input: serde_json::json!(["a", "b", "c"]),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let json = serde_json::to_string(&original).unwrap();
    let round_tripped: EmbeddingRequest = serde_json::from_str(&json).unwrap();

    assert_eq!(round_tripped.model, original.model);
    assert_eq!(round_tripped.input, original.input);
}
#[test]
fn test_embedding_generation_request_trait_string() {
    // A plain string input is used verbatim as the routing text.
    let request = EmbeddingRequest {
        model: String::from("emb-model"),
        input: serde_json::json!("hello"),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    assert!(!request.is_stream());
    assert_eq!(request.get_model(), Some("emb-model"));
    assert_eq!(request.extract_text_for_routing(), "hello");
}
#[test]
fn test_embedding_generation_request_trait_array() {
    // String elements of an array are joined with a single space.
    let request = EmbeddingRequest {
        model: String::from("emb-model"),
        input: serde_json::json!(["hello", "world"]),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let routing_text = request.extract_text_for_routing();
    assert_eq!(routing_text, "hello world");
}
#[test]
fn test_embedding_generation_request_trait_non_text() {
    // An object input carries no routable text, so the result is empty.
    let request = EmbeddingRequest {
        model: String::from("emb-model"),
        input: serde_json::json!({"tokens": [1, 2, 3]}),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let routing_text = request.extract_text_for_routing();
    assert_eq!(routing_text, "");
}
#[test]
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
    // Nested arrays, numbers and objects are skipped; only the top-level
    // string element "a" contributes to the routing text.
    let request = EmbeddingRequest {
        model: String::from("emb-model"),
        input: serde_json::json!(["a", ["b", "c"], 123, {"k": "v"}]),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let routing_text = request.extract_text_for_routing();
    assert_eq!(routing_text, "a");
}
}

View File

@@ -309,7 +309,12 @@ impl RouterTrait for GrpcPDRouter {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
/// Embeddings stub: ignores all arguments and answers `501 Not Implemented`
/// with no body text.
async fn route_embeddings(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &crate::protocols::spec::EmbeddingRequest,
    _model_id: Option<&str>,
) -> Response {
    let status = StatusCode::NOT_IMPLEMENTED;
    status.into_response()
}

View File

@@ -242,7 +242,12 @@ impl RouterTrait for GrpcRouter {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
/// Embeddings stub: ignores all arguments and answers `501 Not Implemented`
/// with no body text.
async fn route_embeddings(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &crate::protocols::spec::EmbeddingRequest,
    _model_id: Option<&str>,
) -> Response {
    let status = StatusCode::NOT_IMPLEMENTED;
    status.into_response()
}

View File

@@ -395,7 +395,12 @@ impl super::super::RouterTrait for OpenAIRouter {
}
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
async fn route_embeddings(
&self,
_headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::EmbeddingRequest,
_model_id: Option<&str>,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Embeddings endpoint not implemented for OpenAI backend",

View File

@@ -1938,8 +1938,17 @@ impl RouterTrait for PDRouter {
.into_response()
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
/// Embeddings stub: ignores all arguments and answers `501 Not Implemented`
/// with a short explanatory message.
async fn route_embeddings(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &crate::protocols::spec::EmbeddingRequest,
    _model_id: Option<&str>,
) -> Response {
    let reply = (
        StatusCode::NOT_IMPLEMENTED,
        "Embeddings endpoint not implemented for PD router",
    );
    reply.into_response()
}
async fn route_rerank(

View File

@@ -6,8 +6,8 @@ use crate::core::{
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest,
RerankResponse, RerankResult, ResponsesRequest,
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
RerankRequest, RerankResponse, RerankResult, ResponsesRequest,
};
use crate::routers::header_utils;
use crate::routers::{RouterTrait, WorkerManagement};
@@ -1430,8 +1430,28 @@ impl RouterTrait for Router {
self.route_post_empty_request(headers, &endpoint).await
}
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
todo!()
/// Forwards an embedding request through `route_typed_request` on the
/// `/v1/embeddings` route and records embedding-specific metrics.
///
/// On a successful status the request counter and the duration histogram are
/// updated; on any other status only the error counter is bumped, labelled
/// `http_<status code>`. The response is returned to the caller unchanged.
async fn route_embeddings(
    &self,
    headers: Option<&HeaderMap>,
    body: &EmbeddingRequest,
    model_id: Option<&str>,
) -> Response {
    let started = Instant::now();
    let response = self
        .route_typed_request(headers, body, "/v1/embeddings", model_id)
        .await;

    let status = response.status();
    if !status.is_success() {
        // Failure path: only the error counter is touched.
        let error_type = format!("http_{}", status.as_u16());
        RouterMetrics::record_embeddings_error(&error_type);
        return response;
    }

    RouterMetrics::record_embeddings_request();
    RouterMetrics::record_embeddings_duration(started.elapsed());
    response
}
async fn route_rerank(

View File

@@ -10,7 +10,8 @@ use axum::{
use std::fmt::Debug;
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesRequest,
};
pub mod factory;
@@ -123,7 +124,13 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
.into_response()
}
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
/// Route embedding requests (OpenAI-compatible /v1/embeddings).
///
/// * `headers` - request headers, when available.
/// * `body` - the already-deserialized embedding request.
/// * `model_id` - optional model identifier.
///   NOTE(review): how this is interpreted is left to each implementation — confirm.
///
/// Returns the HTTP response to send back to the client.
async fn route_embeddings(
&self,
headers: Option<&HeaderMap>,
body: &EmbeddingRequest,
model_id: Option<&str>,
) -> Response;
async fn route_rerank(
&self,

View File

@@ -7,7 +7,8 @@
use crate::config::RouterConfig;
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesRequest,
};
use crate::protocols::worker_spec::{
ServerInfo, WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse, WorkerInfo,
@@ -665,22 +666,6 @@ impl RouterTrait for RouterManager {
.into_response()
}
async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"responses api not yet implemented in inference gateway mode",
)
.into_response()
}
async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"responses api not yet implemented in inference gateway mode",
)
.into_response()
}
async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
@@ -701,17 +686,51 @@ impl RouterTrait for RouterManager {
.into_response()
}
/// Route embeddings request
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response {
// Try to select a router based on headers
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
let router = self.select_router_for_request(headers, None);
if let Some(router) = router {
router.route_embeddings(headers, body).await
router.get_response(headers, response_id).await
} else {
(
StatusCode::NOT_FOUND,
"No router available for embeddings request",
format!("No router available to get response '{}'", response_id),
)
.into_response()
}
}
/// Delegates cancellation of `response_id` to the router selected from the
/// request headers (no model hint is passed to the selection).
///
/// Answers `404 Not Found` with an explanatory message when no router is selected.
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
    let selected = self.select_router_for_request(headers, None);
    match selected {
        Some(router) => router.cancel_response(headers, response_id).await,
        None => {
            let message = format!("No router available to cancel response '{}'", response_id);
            (StatusCode::NOT_FOUND, message).into_response()
        }
    }
}
/// Route embeddings request
async fn route_embeddings(
&self,
headers: Option<&HeaderMap>,
body: &EmbeddingRequest,
_model_id: Option<&str>,
) -> Response {
// Select router based on headers and model
let router = self.select_router_for_request(headers, Some(&body.model));
if let Some(router) = router {
router
.route_embeddings(headers, body, Some(&body.model))
.await
} else {
// Return 404 when the specified model is not found
(
StatusCode::NOT_FOUND,
format!("Model '{}' not found or no router available", body.model),
)
.into_response()
}

View File

@@ -5,8 +5,8 @@ use crate::metrics::{self, PrometheusConfig};
use crate::middleware::TokenBucket;
use crate::policies::PolicyRegistry;
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
V1RerankReqInput,
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponsesRequest, V1RerankReqInput,
};
use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse};
use crate::reasoning_parser::ParserFactory;
@@ -208,6 +208,17 @@ async fn v1_responses(
.await
}
/// Handler for `POST /v1/embeddings`: passes the JSON-decoded request and the
/// incoming headers to the application's router, with no explicit model id.
async fn v1_embeddings(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<EmbeddingRequest>,
) -> Response {
    let router = &state.router;
    router.route_embeddings(Some(&headers), &body, None).await
}
async fn v1_responses_get(
State(state): State<Arc<AppState>>,
Path(response_id): Path<String>,
@@ -465,6 +476,7 @@ pub fn build_app(
.route("/rerank", post(rerank))
.route("/v1/rerank", post(v1_rerank))
.route("/v1/responses", post(v1_responses))
.route("/v1/embeddings", post(v1_embeddings))
.route("/v1/responses/{response_id}", get(v1_responses_get))
.route(
"/v1/responses/{response_id}/cancel",