[2/2] [feature] support OpenAI-like classification API in router (#11670)

This commit is contained in:
ybyang
2025-10-19 10:31:08 +08:00
committed by GitHub
parent a7ae61ed77
commit d513ee93ef
14 changed files with 257 additions and 45 deletions

View File

@@ -18,6 +18,7 @@ use crate::{
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
@@ -254,6 +255,15 @@ impl RouterTrait for GrpcPDRouter {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
/// Classification stub for the gRPC PD router.
///
/// All three arguments are ignored; the reply is always the bare
/// `StatusCode::NOT_IMPLEMENTED` status converted into a response.
async fn route_classify(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &ClassifyRequest,
    _model_id: Option<&str>,
) -> Response {
    let unsupported = StatusCode::NOT_IMPLEMENTED;
    unsupported.into_response()
}
async fn route_embeddings(
&self,
_headers: Option<&HeaderMap>,

View File

@@ -18,6 +18,7 @@ use crate::{
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
@@ -236,6 +237,15 @@ impl RouterTrait for GrpcRouter {
(StatusCode::NOT_IMPLEMENTED).into_response()
}
/// Classification stub for the gRPC router.
///
/// All three arguments are ignored; the reply is always the bare
/// `StatusCode::NOT_IMPLEMENTED` status converted into a response.
async fn route_classify(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &ClassifyRequest,
    _model_id: Option<&str>,
) -> Response {
    let unsupported = StatusCode::NOT_IMPLEMENTED;
    unsupported.into_response()
}
async fn route_embeddings(
&self,
_headers: Option<&HeaderMap>,

View File

@@ -24,6 +24,7 @@ use crate::{
policies::{LoadBalancingPolicy, PolicyRegistry},
protocols::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
classify::ClassifyRequest,
common::{InputIds, StringOrArray},
completion::CompletionRequest,
embedding::EmbeddingRequest,
@@ -1190,6 +1191,19 @@ impl RouterTrait for PDRouter {
.into_response()
}
/// Classification stub for the PD router.
///
/// The arguments are unused. The reply pairs `StatusCode::NOT_IMPLEMENTED`
/// with a fixed explanatory message.
async fn route_classify(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &ClassifyRequest,
    _model_id: Option<&str>,
) -> Response {
    let reply = (
        StatusCode::NOT_IMPLEMENTED,
        "Classify endpoint not implemented for PD router",
    );
    reply.into_response()
}
async fn route_embeddings(
&self,
_headers: Option<&HeaderMap>,

View File

@@ -24,6 +24,7 @@ use crate::{
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
common::GenerationRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
@@ -749,6 +750,30 @@ impl RouterTrait for Router {
res
}
/// Forwards a classification request to the `/v1/classify` route and
/// records classification-specific metrics for the outcome.
///
/// * Non-success status: one error is recorded, labelled `http_<code>`.
/// * Success status: one request and its elapsed duration are recorded.
///
/// The response from `route_typed_request` is returned untouched in both
/// cases.
async fn route_classify(
    &self,
    headers: Option<&HeaderMap>,
    body: &ClassifyRequest,
    model_id: Option<&str>,
) -> Response {
    // Timer starts before routing so the duration covers the whole call.
    let started_at = Instant::now();
    let response = self
        .route_typed_request(headers, body, "/v1/classify", model_id)
        .await;

    let status = response.status();
    if !status.is_success() {
        // Failure path: only the error counter is touched.
        let label = format!("http_{}", status.as_u16());
        RouterMetrics::record_classify_error(&label);
        return response;
    }

    // Success path: count the request, then record how long it took.
    RouterMetrics::record_classify_request();
    RouterMetrics::record_classify_duration(started_at.elapsed());
    response
}
async fn route_rerank(
&self,
headers: Option<&HeaderMap>,

View File

@@ -13,6 +13,7 @@ use serde_json::Value;
use crate::protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
@@ -125,6 +126,14 @@ pub trait RouterTrait: Send + Sync + Debug {
model_id: Option<&str>,
) -> Response;
/// Route classification requests (OpenAI-compatible /v1/classify)
///
/// * `headers` - optional HTTP headers accompanying the request.
/// * `body` - the classification request to route.
/// * `model_id` - optional model identifier.
///   NOTE(review): presumably used for model/worker selection; confirm
///   against the implementors.
///
/// Returns the HTTP response produced by the implementation.
///
/// This is a required method: the declaration has no default body, so every
/// implementor of the trait must supply its own.
async fn route_classify(
&self,
headers: Option<&HeaderMap>,
body: &ClassifyRequest,
model_id: Option<&str>,
) -> Response;
async fn route_rerank(
&self,
headers: Option<&HeaderMap>,

View File

@@ -41,10 +41,11 @@ pub(super) async fn create_conversation(
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"metadata cannot have more than {} properties",
MAX_METADATA_PROPERTIES
)
"error":
format!(
"metadata cannot have more than {} properties",
MAX_METADATA_PROPERTIES
)
})),
)
.into_response();
@@ -70,7 +71,9 @@ pub(super) async fn create_conversation(
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create conversation: {}", e)})),
Json(json!({
"error": format!("Failed to create conversation: {}", e)
})),
)
.into_response(),
}
@@ -97,7 +100,9 @@ pub(super) async fn get_conversation(
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
)
.into_response(),
}
@@ -126,7 +131,9 @@ pub(super) async fn update_conversation(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
)
.into_response();
}
@@ -174,10 +181,11 @@ pub(super) async fn update_conversation(
return (
StatusCode::BAD_REQUEST,
Json(json!({
"error": format!(
"metadata cannot have more than {} properties",
MAX_METADATA_PROPERTIES
)
"error":
format!(
"metadata cannot have more than {} properties",
MAX_METADATA_PROPERTIES
)
})),
)
.into_response();
@@ -204,7 +212,9 @@ pub(super) async fn update_conversation(
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to update conversation: {}", e)})),
Json(json!({
"error": format!("Failed to update conversation: {}", e)
})),
)
.into_response(),
}
@@ -232,7 +242,9 @@ pub(super) async fn delete_conversation(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
)
.into_response();
}
@@ -256,7 +268,9 @@ pub(super) async fn delete_conversation(
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to delete conversation: {}", e)})),
Json(json!({
"error": format!("Failed to delete conversation: {}", e)
})),
)
.into_response(),
}
@@ -286,7 +300,9 @@ pub(super) async fn list_conversation_items(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
)
.into_response();
}
@@ -346,7 +362,7 @@ pub(super) async fn list_conversation_items(
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to list items: {}", e)})),
Json(json!({ "error": format!("Failed to list items: {}", e) })),
)
.into_response(),
}
@@ -417,7 +433,9 @@ pub(super) async fn create_conversation_items(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
)
.into_response();
}
@@ -476,14 +494,18 @@ pub(super) async fn create_conversation_items(
Ok(None) => {
return (
StatusCode::NOT_FOUND,
Json(json!({"error": format!("Referenced item '{}' not found", ref_id)})),
Json(json!({
"error": format!("Referenced item '{}' not found", ref_id)
})),
)
.into_response();
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get referenced item: {}", e)})),
Json(json!({
"error": format!("Failed to get referenced item: {}", e)
})),
)
.into_response();
}
@@ -517,7 +539,9 @@ pub(super) async fn create_conversation_items(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to check item link: {}", e)})),
Json(json!({
"error": format!("Failed to check item link: {}", e)
})),
)
.into_response();
}
@@ -553,7 +577,7 @@ pub(super) async fn create_conversation_items(
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": format!("Invalid item: {}", e)})),
Json(json!({ "error": format!("Invalid item: {}", e) })),
)
.into_response();
}
@@ -570,7 +594,7 @@ pub(super) async fn create_conversation_items(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create item: {}", e)})),
Json(json!({ "error": format!("Failed to create item: {}", e) })),
)
.into_response();
}
@@ -579,7 +603,9 @@ pub(super) async fn create_conversation_items(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to check item existence: {}", e)})),
Json(json!({
"error": format!("Failed to check item existence: {}", e)
})),
)
.into_response();
}
@@ -593,7 +619,7 @@ pub(super) async fn create_conversation_items(
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(json!({"error": format!("Invalid item: {}", e)})),
Json(json!({ "error": format!("Invalid item: {}", e) })),
)
.into_response();
}
@@ -610,7 +636,7 @@ pub(super) async fn create_conversation_items(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to create item: {}", e)})),
Json(json!({ "error": format!("Failed to create item: {}", e) })),
)
.into_response();
}
@@ -678,7 +704,9 @@ pub(super) async fn get_conversation_item(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
)
.into_response();
}
@@ -693,7 +721,9 @@ pub(super) async fn get_conversation_item(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to check item link: {}", e)})),
Json(json!({
"error": format!("Failed to check item link: {}", e)
})),
)
.into_response();
}
@@ -721,7 +751,7 @@ pub(super) async fn get_conversation_item(
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get item: {}", e)})),
Json(json!({ "error": format!("Failed to get item: {}", e) })),
)
.into_response(),
}
@@ -753,7 +783,9 @@ pub(super) async fn delete_conversation_item(
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
Json(json!({
"error": format!("Failed to get conversation: {}", e)
})),
)
.into_response();
}
@@ -773,7 +805,7 @@ pub(super) async fn delete_conversation_item(
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to delete item: {}", e)})),
Json(json!({ "error": format!("Failed to delete item: {}", e) })),
)
.into_response(),
}

View File

@@ -156,7 +156,7 @@ pub(super) fn patch_streaming_response_json(
// Attach conversation id for client response if present (final aggregated JSON)
if let Some(conv_id) = original_body.conversation.clone() {
obj.insert("conversation".to_string(), json!({"id": conv_id}));
obj.insert("conversation".to_string(), json!({ "id": conv_id }));
}
}
}
@@ -234,7 +234,7 @@ pub(super) fn rewrite_streaming_block(
// Attach conversation id into streaming event response content with ordering
if let Some(conv_id) = original_body.conversation.clone() {
response_obj.insert("conversation".to_string(), json!({"id": conv_id}));
response_obj.insert("conversation".to_string(), json!({ "id": conv_id }));
changed = true;
}
}

View File

@@ -42,6 +42,7 @@ use crate::{
},
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
@@ -828,7 +829,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
.into_response(),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({"error": format!("Failed to get response: {}", e)})),
Json(json!({ "error": format!("Failed to get response: {}", e) })),
)
.into_response(),
}
@@ -882,6 +883,15 @@ impl crate::routers::RouterTrait for OpenAIRouter {
(StatusCode::NOT_IMPLEMENTED, "Rerank not supported").into_response()
}
/// Classification stub for the OpenAI router.
///
/// The arguments are unused. The reply pairs `StatusCode::NOT_IMPLEMENTED`
/// with a fixed "not supported" message.
async fn route_classify(
    &self,
    _headers: Option<&HeaderMap>,
    _body: &ClassifyRequest,
    _model_id: Option<&str>,
) -> Response {
    let reply = (StatusCode::NOT_IMPLEMENTED, "Classify not supported");
    reply.into_response()
}
/// Creates a conversation by delegating to the free function
/// `create_conversation`, passing this router's `conversation_storage` and
/// an owned clone of `body`. `_headers` is not used.
async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response {
    // The free function takes the payload by value, so hand it a copy.
    let payload = body.clone();
    create_conversation(&self.conversation_storage, payload).await
}

View File

@@ -22,6 +22,7 @@ use crate::{
core::{WorkerRegistry, WorkerType},
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
@@ -329,10 +330,7 @@ impl RouterTrait for RouterManager {
} else {
(
StatusCode::OK,
serde_json::json!({
"models": models
})
.to_string(),
serde_json::json!({ "models": models }).to_string(),
)
.into_response()
}
@@ -517,6 +515,25 @@ impl RouterTrait for RouterManager {
}
}
/// Dispatches a classification request to the router selected for it.
///
/// Selection is done by `select_router_for_request` using the request
/// headers and the `model` field of `body`. When a router is found, the call
/// is forwarded to its `route_classify` with the same arguments. When none is
/// found, the reply is `404 Not Found` with a message naming the requested
/// model.
async fn route_classify(
    &self,
    headers: Option<&HeaderMap>,
    body: &ClassifyRequest,
    model_id: Option<&str>,
) -> Response {
    let selected = self.select_router_for_request(headers, Some(&body.model));
    match selected {
        Some(target) => target.route_classify(headers, body, model_id).await,
        None => {
            let message = format!("Model '{}' not found or no router available", body.model);
            (StatusCode::NOT_FOUND, message).into_response()
        }
    }
}
fn router_type(&self) -> &'static str {
"manager"
}