diff --git a/sgl-router/src/metrics.rs b/sgl-router/src/metrics.rs index 14dfb4ca5..b063a46b2 100644 --- a/sgl-router/src/metrics.rs +++ b/sgl-router/src/metrics.rs @@ -480,6 +480,26 @@ impl RouterMetrics { gauge!("sgl_router_embeddings_queue_size").set(size as f64); } + pub fn record_classify_request() { + counter!("sgl_router_classify_total").increment(1); + } + + pub fn record_classify_duration(duration: Duration) { + histogram!("sgl_router_classify_duration_seconds").record(duration.as_secs_f64()); + } + + pub fn record_classify_error(error_type: &str) { + counter!( + "sgl_router_classify_errors_total", + "error_type" => error_type.to_string() + ) + .increment(1); + } + + pub fn set_classify_queue_size(size: usize) { + gauge!("sgl_router_classify_queue_size").set(size as f64); + } + pub fn set_running_requests(worker: &str, count: usize) { gauge!("sgl_router_running_requests", "worker" => worker.to_string() diff --git a/sgl-router/src/protocols/classify.rs b/sgl-router/src/protocols/classify.rs new file mode 100644 index 000000000..fc7e8b871 --- /dev/null +++ b/sgl-router/src/protocols/classify.rs @@ -0,0 +1,57 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use super::common::GenerationRequest; + +// ============================================================================ +// Classify API +// ============================================================================ + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ClassifyRequest { + /// ID of the model to use + pub model: String, + + /// Input can be a string, array of strings, tokens, or batch inputs + pub input: Value, + + /// Optional encoding format (e.g., "float", "base64") + #[serde(skip_serializing_if = "Option::is_none")] + pub encoding_format: Option<String>, + + /// Optional user identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option<String>, + + /// Optional number of dimensions for the embedding + #[serde(skip_serializing_if = "Option::is_none")] 
+ pub dimensions: Option<u32>, + + /// SGLang extension: request id for tracking + #[serde(skip_serializing_if = "Option::is_none")] + pub rid: Option<String>, +} + +impl GenerationRequest for ClassifyRequest { + fn is_stream(&self) -> bool { + // Classification requests are non-streaming + false + } + + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn extract_text_for_routing(&self) -> String { + // Best effort: extract text content for routing decisions + match &self.input { + Value::String(s) => s.clone(), + Value::Array(arr) => arr + .iter() + .filter_map(|v| v.as_str()) + .collect::<Vec<_>>() + .join(" "), + _ => String::new(), + } + } +} diff --git a/sgl-router/src/protocols/mod.rs b/sgl-router/src/protocols/mod.rs index 5ba6b1893..d9d2526b9 100644 --- a/sgl-router/src/protocols/mod.rs +++ b/sgl-router/src/protocols/mod.rs @@ -2,6 +2,7 @@ // This module provides a structured approach to handling different API protocols pub mod chat; +pub mod classify; pub mod common; pub mod completion; pub mod embedding; diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 243df61b5..7adeb0fb7 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -18,6 +18,7 @@ use crate::{ policies::PolicyRegistry, protocols::{ chat::ChatCompletionRequest, + classify::ClassifyRequest, completion::CompletionRequest, embedding::EmbeddingRequest, generate::GenerateRequest, @@ -254,6 +255,15 @@ impl RouterTrait for GrpcPDRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } + async fn route_classify( + &self, + _headers: Option<&HeaderMap>, + _body: &ClassifyRequest, + _model_id: Option<&str>, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + async fn route_embeddings( &self, _headers: Option<&HeaderMap>, diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index ceed70e62..ff7009506 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ 
b/sgl-router/src/routers/grpc/router.rs @@ -18,6 +18,7 @@ use crate::{ policies::PolicyRegistry, protocols::{ chat::ChatCompletionRequest, + classify::ClassifyRequest, completion::CompletionRequest, embedding::EmbeddingRequest, generate::GenerateRequest, @@ -236,6 +237,15 @@ impl RouterTrait for GrpcRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } + async fn route_classify( + &self, + _headers: Option<&HeaderMap>, + _body: &ClassifyRequest, + _model_id: Option<&str>, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + async fn route_embeddings( &self, _headers: Option<&HeaderMap>, diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 078bb5080..c6554c162 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -24,6 +24,7 @@ use crate::{ policies::{LoadBalancingPolicy, PolicyRegistry}, protocols::{ chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}, + classify::ClassifyRequest, common::{InputIds, StringOrArray}, completion::CompletionRequest, embedding::EmbeddingRequest, @@ -1190,6 +1191,19 @@ impl RouterTrait for PDRouter { .into_response() } + async fn route_classify( + &self, + _headers: Option<&HeaderMap>, + _body: &ClassifyRequest, + _model_id: Option<&str>, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Classify endpoint not implemented for PD router", + ) + .into_response() + } + async fn route_embeddings( &self, _headers: Option<&HeaderMap>, diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index 911859a0d..ca607be40 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -24,6 +24,7 @@ use crate::{ policies::PolicyRegistry, protocols::{ chat::ChatCompletionRequest, + classify::ClassifyRequest, common::GenerationRequest, completion::CompletionRequest, embedding::EmbeddingRequest, @@ -749,6 +750,30 @@ impl RouterTrait for Router { res 
} + async fn route_classify( + &self, + headers: Option<&HeaderMap>, + body: &ClassifyRequest, + model_id: Option<&str>, + ) -> Response { + // Record classification-specific metrics in addition to general request metrics + let start = Instant::now(); + let res = self + .route_typed_request(headers, body, "/v1/classify", model_id) + .await; + + // Classification specific metrics + if res.status().is_success() { + RouterMetrics::record_classify_request(); + RouterMetrics::record_classify_duration(start.elapsed()); + } else { + let error_type = format!("http_{}", res.status().as_u16()); + RouterMetrics::record_classify_error(&error_type); + } + + res + } + async fn route_rerank( &self, headers: Option<&HeaderMap>, diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index b034605dd..aaa963468 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -13,6 +13,7 @@ use serde_json::Value; use crate::protocols::{ chat::ChatCompletionRequest, + classify::ClassifyRequest, completion::CompletionRequest, embedding::EmbeddingRequest, generate::GenerateRequest, @@ -125,6 +126,14 @@ pub trait RouterTrait: Send + Sync + Debug { model_id: Option<&str>, ) -> Response; + /// Route classification requests (OpenAI-compatible /v1/classify) + async fn route_classify( + &self, + headers: Option<&HeaderMap>, + body: &ClassifyRequest, + model_id: Option<&str>, + ) -> Response; + async fn route_rerank( &self, headers: Option<&HeaderMap>, diff --git a/sgl-router/src/routers/openai/conversations.rs b/sgl-router/src/routers/openai/conversations.rs index 8d9b8c0a1..09e0a56d4 100644 --- a/sgl-router/src/routers/openai/conversations.rs +++ b/sgl-router/src/routers/openai/conversations.rs @@ -41,10 +41,11 @@ pub(super) async fn create_conversation( return ( StatusCode::BAD_REQUEST, Json(json!({ - "error": format!( - "metadata cannot have more than {} properties", - MAX_METADATA_PROPERTIES - ) + "error": + format!( + "metadata cannot have more than {} 
properties", + MAX_METADATA_PROPERTIES + ) })), ) .into_response(); @@ -70,7 +71,9 @@ pub(super) async fn create_conversation( } Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to create conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to create conversation: {}", e) + })), ) .into_response(), } @@ -97,7 +100,9 @@ pub(super) async fn get_conversation( .into_response(), Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to get conversation: {}", e) + })), ) .into_response(), } @@ -126,7 +131,9 @@ pub(super) async fn update_conversation( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to get conversation: {}", e) + })), ) .into_response(); } @@ -174,10 +181,11 @@ pub(super) async fn update_conversation( return ( StatusCode::BAD_REQUEST, Json(json!({ - "error": format!( - "metadata cannot have more than {} properties", - MAX_METADATA_PROPERTIES - ) + "error": + format!( + "metadata cannot have more than {} properties", + MAX_METADATA_PROPERTIES + ) })), ) .into_response(); @@ -204,7 +212,9 @@ pub(super) async fn update_conversation( .into_response(), Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to update conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to update conversation: {}", e) + })), ) .into_response(), } @@ -232,7 +242,9 @@ pub(super) async fn delete_conversation( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to get conversation: {}", e) + })), ) .into_response(); } @@ -256,7 +268,9 @@ pub(super) async fn delete_conversation( } Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": 
format!("Failed to delete conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to delete conversation: {}", e) + })), ) .into_response(), } @@ -286,7 +300,9 @@ pub(super) async fn list_conversation_items( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to get conversation: {}", e) + })), ) .into_response(); } @@ -346,7 +362,7 @@ pub(super) async fn list_conversation_items( } Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to list items: {}", e)})), + Json(json!({ "error": format!("Failed to list items: {}", e) })), ) .into_response(), } @@ -417,7 +433,9 @@ pub(super) async fn create_conversation_items( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to get conversation: {}", e) + })), ) .into_response(); } @@ -476,14 +494,18 @@ pub(super) async fn create_conversation_items( Ok(None) => { return ( StatusCode::NOT_FOUND, - Json(json!({"error": format!("Referenced item '{}' not found", ref_id)})), + Json(json!({ + "error": format!("Referenced item '{}' not found", ref_id) + })), ) .into_response(); } Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get referenced item: {}", e)})), + Json(json!({ + "error": format!("Failed to get referenced item: {}", e) + })), ) .into_response(); } @@ -517,7 +539,9 @@ pub(super) async fn create_conversation_items( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to check item link: {}", e)})), + Json(json!({ + "error": format!("Failed to check item link: {}", e) + })), ) .into_response(); } @@ -553,7 +577,7 @@ pub(super) async fn create_conversation_items( Err(e) => { return ( StatusCode::BAD_REQUEST, - Json(json!({"error": format!("Invalid item: 
{}", e)})), + Json(json!({ "error": format!("Invalid item: {}", e) })), ) .into_response(); } @@ -570,7 +594,7 @@ pub(super) async fn create_conversation_items( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to create item: {}", e)})), + Json(json!({ "error": format!("Failed to create item: {}", e) })), ) .into_response(); } @@ -579,7 +603,9 @@ pub(super) async fn create_conversation_items( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to check item existence: {}", e)})), + Json(json!({ + "error": format!("Failed to check item existence: {}", e) + })), ) .into_response(); } @@ -593,7 +619,7 @@ pub(super) async fn create_conversation_items( Err(e) => { return ( StatusCode::BAD_REQUEST, - Json(json!({"error": format!("Invalid item: {}", e)})), + Json(json!({ "error": format!("Invalid item: {}", e) })), ) .into_response(); } @@ -610,7 +636,7 @@ pub(super) async fn create_conversation_items( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to create item: {}", e)})), + Json(json!({ "error": format!("Failed to create item: {}", e) })), ) .into_response(); } @@ -678,7 +704,9 @@ pub(super) async fn get_conversation_item( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to get conversation: {}", e) + })), ) .into_response(); } @@ -693,7 +721,9 @@ pub(super) async fn get_conversation_item( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to check item link: {}", e)})), + Json(json!({ + "error": format!("Failed to check item link: {}", e) + })), ) .into_response(); } @@ -721,7 +751,7 @@ pub(super) async fn get_conversation_item( .into_response(), Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get item: {}", e)})), + Json(json!({ 
"error": format!("Failed to get item: {}", e) })), ) .into_response(), } @@ -753,7 +783,9 @@ pub(super) async fn delete_conversation_item( Err(e) => { return ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get conversation: {}", e)})), + Json(json!({ + "error": format!("Failed to get conversation: {}", e) + })), ) .into_response(); } @@ -773,7 +805,7 @@ pub(super) async fn delete_conversation_item( } Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to delete item: {}", e)})), + Json(json!({ "error": format!("Failed to delete item: {}", e) })), ) .into_response(), } diff --git a/sgl-router/src/routers/openai/responses.rs b/sgl-router/src/routers/openai/responses.rs index 3931ac019..27d0a4289 100644 --- a/sgl-router/src/routers/openai/responses.rs +++ b/sgl-router/src/routers/openai/responses.rs @@ -156,7 +156,7 @@ pub(super) fn patch_streaming_response_json( // Attach conversation id for client response if present (final aggregated JSON) if let Some(conv_id) = original_body.conversation.clone() { - obj.insert("conversation".to_string(), json!({"id": conv_id})); + obj.insert("conversation".to_string(), json!({ "id": conv_id })); } } } @@ -234,7 +234,7 @@ pub(super) fn rewrite_streaming_block( // Attach conversation id into streaming event response content with ordering if let Some(conv_id) = original_body.conversation.clone() { - response_obj.insert("conversation".to_string(), json!({"id": conv_id})); + response_obj.insert("conversation".to_string(), json!({ "id": conv_id })); changed = true; } } diff --git a/sgl-router/src/routers/openai/router.rs b/sgl-router/src/routers/openai/router.rs index 3008929ea..ac94660cd 100644 --- a/sgl-router/src/routers/openai/router.rs +++ b/sgl-router/src/routers/openai/router.rs @@ -42,6 +42,7 @@ use crate::{ }, protocols::{ chat::ChatCompletionRequest, + classify::ClassifyRequest, completion::CompletionRequest, embedding::EmbeddingRequest, 
generate::GenerateRequest, @@ -828,7 +829,7 @@ impl crate::routers::RouterTrait for OpenAIRouter { .into_response(), Err(e) => ( StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({"error": format!("Failed to get response: {}", e)})), + Json(json!({ "error": format!("Failed to get response: {}", e) })), ) .into_response(), } @@ -882,6 +883,15 @@ impl crate::routers::RouterTrait for OpenAIRouter { (StatusCode::NOT_IMPLEMENTED, "Rerank not supported").into_response() } + async fn route_classify( + &self, + _headers: Option<&HeaderMap>, + _body: &ClassifyRequest, + _model_id: Option<&str>, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED, "Classify not supported").into_response() + } + async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response { create_conversation(&self.conversation_storage, body.clone()).await } diff --git a/sgl-router/src/routers/router_manager.rs b/sgl-router/src/routers/router_manager.rs index 9cce72269..f249281e6 100644 --- a/sgl-router/src/routers/router_manager.rs +++ b/sgl-router/src/routers/router_manager.rs @@ -22,6 +22,7 @@ use crate::{ core::{WorkerRegistry, WorkerType}, protocols::{ chat::ChatCompletionRequest, + classify::ClassifyRequest, completion::CompletionRequest, embedding::EmbeddingRequest, generate::GenerateRequest, @@ -329,10 +330,7 @@ impl RouterTrait for RouterManager { } else { ( StatusCode::OK, - serde_json::json!({ - "models": models - }) - .to_string(), + serde_json::json!({ "models": models }).to_string(), ) .into_response() } @@ -517,6 +515,25 @@ impl RouterTrait for RouterManager { } } + async fn route_classify( + &self, + headers: Option<&HeaderMap>, + body: &ClassifyRequest, + model_id: Option<&str>, + ) -> Response { + let router = self.select_router_for_request(headers, Some(&body.model)); + + if let Some(router) = router { + router.route_classify(headers, body, model_id).await + } else { + ( + StatusCode::NOT_FOUND, + format!("Model '{}' not found or no router available", body.model), 
+ ) + .into_response() + } + } + fn router_type(&self) -> &'static str { "manager" } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 762c3fc28..8be88e995 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -37,6 +37,7 @@ use crate::{ policies::PolicyRegistry, protocols::{ chat::ChatCompletionRequest, + classify::ClassifyRequest, completion::CompletionRequest, embedding::EmbeddingRequest, generate::GenerateRequest, @@ -270,6 +271,17 @@ async fn v1_embeddings( .await } +async fn v1_classify( + State(state): State<Arc<AppState>>, + headers: http::HeaderMap, + Json(body): Json<ClassifyRequest>, +) -> Response { + state + .router + .route_classify(Some(&headers), &body, None) + .await +} + async fn v1_responses_get( State(state): State<Arc<AppState>>, Path(response_id): Path<String>, @@ -534,13 +546,7 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons }) .collect(); - ( - StatusCode::OK, - Json(json!({ - "workers": loads - })), - ) - .into_response() + (StatusCode::OK, Json(json!({ "workers": loads }))).into_response() } async fn create_worker( @@ -707,6 +713,7 @@ pub fn build_app( .route("/v1/rerank", post(v1_rerank)) .route("/v1/responses", post(v1_responses)) .route("/v1/embeddings", post(v1_embeddings)) + .route("/v1/classify", post(v1_classify)) .route("/v1/responses/{response_id}", get(v1_responses_get)) .route( "/v1/responses/{response_id}/cancel", diff --git a/sgl-router/tests/responses_api_test.rs b/sgl-router/tests/responses_api_test.rs index 589b0ce5d..31db26f62 100644 --- a/sgl-router/tests/responses_api_test.rs +++ b/sgl-router/tests/responses_api_test.rs @@ -1617,7 +1617,7 @@ async fn test_conversation_items_max_limit() { "content": [{"type": "input_text", "text": format!("Message {}", i)}] })); } - let create_items = serde_json::json!({"items": items}); + let create_items = serde_json::json!({ "items": items }); let items_resp = router .create_conversation_items(None, conv_id, &create_items)