[2/2] [feature] support openai like classification api in router (#11670)
This commit is contained in:
@@ -480,6 +480,26 @@ impl RouterMetrics {
|
||||
gauge!("sgl_router_embeddings_queue_size").set(size as f64);
|
||||
}
|
||||
|
||||
pub fn record_classify_request() {
|
||||
counter!("sgl_router_classify_total").increment(1);
|
||||
}
|
||||
|
||||
pub fn record_classify_duration(duration: Duration) {
|
||||
histogram!("sgl_router_classify_duration_seconds").record(duration.as_secs_f64());
|
||||
}
|
||||
|
||||
pub fn record_classify_error(error_type: &str) {
|
||||
counter!(
|
||||
"sgl_router_classify_errors_total",
|
||||
"error_type" => error_type.to_string()
|
||||
)
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
pub fn set_classify_queue_size(size: usize) {
|
||||
gauge!("sgl_router_classify_queue_size").set(size as f64);
|
||||
}
|
||||
|
||||
pub fn set_running_requests(worker: &str, count: usize) {
|
||||
gauge!("sgl_router_running_requests",
|
||||
"worker" => worker.to_string()
|
||||
|
||||
57
sgl-router/src/protocols/classify.rs
Normal file
57
sgl-router/src/protocols/classify.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::common::GenerationRequest;
|
||||
|
||||
// ============================================================================
|
||||
// Embedding API
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ClassifyRequest {
|
||||
/// ID of the model to use
|
||||
pub model: String,
|
||||
|
||||
/// Input can be a string, array of strings, tokens, or batch inputs
|
||||
pub input: Value,
|
||||
|
||||
/// Optional encoding format (e.g., "float", "base64")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub encoding_format: Option<String>,
|
||||
|
||||
/// Optional user identifier
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
/// Optional number of dimensions for the embedding
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub dimensions: Option<u32>,
|
||||
|
||||
/// SGLang extension: request id for tracking
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rid: Option<String>,
|
||||
}
|
||||
|
||||
impl GenerationRequest for ClassifyRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
// Embeddings are non-streaming
|
||||
false
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
Some(&self.model)
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
// Best effort: extract text content for routing decisions
|
||||
match &self.input {
|
||||
Value::String(s) => s.clone(),
|
||||
Value::Array(arr) => arr
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
_ => String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
// This module provides a structured approach to handling different API protocols
|
||||
|
||||
pub mod chat;
|
||||
pub mod classify;
|
||||
pub mod common;
|
||||
pub mod completion;
|
||||
pub mod embedding;
|
||||
|
||||
@@ -18,6 +18,7 @@ use crate::{
|
||||
policies::PolicyRegistry,
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
classify::ClassifyRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
@@ -254,6 +255,15 @@ impl RouterTrait for GrpcPDRouter {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_classify(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &ClassifyRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
|
||||
@@ -18,6 +18,7 @@ use crate::{
|
||||
policies::PolicyRegistry,
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
classify::ClassifyRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
@@ -236,6 +237,15 @@ impl RouterTrait for GrpcRouter {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_classify(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &ClassifyRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
|
||||
@@ -24,6 +24,7 @@ use crate::{
|
||||
policies::{LoadBalancingPolicy, PolicyRegistry},
|
||||
protocols::{
|
||||
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
|
||||
classify::ClassifyRequest,
|
||||
common::{InputIds, StringOrArray},
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
@@ -1190,6 +1191,19 @@ impl RouterTrait for PDRouter {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_classify(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &ClassifyRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Classify endpoint not implemented for PD router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
|
||||
@@ -24,6 +24,7 @@ use crate::{
|
||||
policies::PolicyRegistry,
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
classify::ClassifyRequest,
|
||||
common::GenerationRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
@@ -749,6 +750,30 @@ impl RouterTrait for Router {
|
||||
res
|
||||
}
|
||||
|
||||
async fn route_classify(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ClassifyRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Record classification-specific metrics in addition to general request metrics
|
||||
let start = Instant::now();
|
||||
let res = self
|
||||
.route_typed_request(headers, body, "/v1/classify", model_id)
|
||||
.await;
|
||||
|
||||
// Classification specific metrics
|
||||
if res.status().is_success() {
|
||||
RouterMetrics::record_classify_request();
|
||||
RouterMetrics::record_classify_duration(start.elapsed());
|
||||
} else {
|
||||
let error_type = format!("http_{}", res.status().as_u16());
|
||||
RouterMetrics::record_classify_error(&error_type);
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
async fn route_rerank(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
|
||||
@@ -13,6 +13,7 @@ use serde_json::Value;
|
||||
|
||||
use crate::protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
classify::ClassifyRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
@@ -125,6 +126,14 @@ pub trait RouterTrait: Send + Sync + Debug {
|
||||
model_id: Option<&str>,
|
||||
) -> Response;
|
||||
|
||||
/// Route classification requests (OpenAI-compatible /v1/classify)
|
||||
async fn route_classify(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ClassifyRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response;
|
||||
|
||||
async fn route_rerank(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
|
||||
@@ -41,10 +41,11 @@ pub(super) async fn create_conversation(
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"metadata cannot have more than {} properties",
|
||||
MAX_METADATA_PROPERTIES
|
||||
)
|
||||
"error":
|
||||
format!(
|
||||
"metadata cannot have more than {} properties",
|
||||
MAX_METADATA_PROPERTIES
|
||||
)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
@@ -70,7 +71,9 @@ pub(super) async fn create_conversation(
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to create conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to create conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
@@ -97,7 +100,9 @@ pub(super) async fn get_conversation(
|
||||
.into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to get conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
@@ -126,7 +131,9 @@ pub(super) async fn update_conversation(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to get conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -174,10 +181,11 @@ pub(super) async fn update_conversation(
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({
|
||||
"error": format!(
|
||||
"metadata cannot have more than {} properties",
|
||||
MAX_METADATA_PROPERTIES
|
||||
)
|
||||
"error":
|
||||
format!(
|
||||
"metadata cannot have more than {} properties",
|
||||
MAX_METADATA_PROPERTIES
|
||||
)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
@@ -204,7 +212,9 @@ pub(super) async fn update_conversation(
|
||||
.into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to update conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to update conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
@@ -232,7 +242,9 @@ pub(super) async fn delete_conversation(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to get conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -256,7 +268,9 @@ pub(super) async fn delete_conversation(
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to delete conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to delete conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
@@ -286,7 +300,9 @@ pub(super) async fn list_conversation_items(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to get conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -346,7 +362,7 @@ pub(super) async fn list_conversation_items(
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to list items: {}", e)})),
|
||||
Json(json!({ "error": format!("Failed to list items: {}", e) })),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
@@ -417,7 +433,9 @@ pub(super) async fn create_conversation_items(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to get conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -476,14 +494,18 @@ pub(super) async fn create_conversation_items(
|
||||
Ok(None) => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(json!({"error": format!("Referenced item '{}' not found", ref_id)})),
|
||||
Json(json!({
|
||||
"error": format!("Referenced item '{}' not found", ref_id)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get referenced item: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to get referenced item: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -517,7 +539,9 @@ pub(super) async fn create_conversation_items(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to check item link: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to check item link: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -553,7 +577,7 @@ pub(super) async fn create_conversation_items(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": format!("Invalid item: {}", e)})),
|
||||
Json(json!({ "error": format!("Invalid item: {}", e) })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -570,7 +594,7 @@ pub(super) async fn create_conversation_items(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to create item: {}", e)})),
|
||||
Json(json!({ "error": format!("Failed to create item: {}", e) })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -579,7 +603,9 @@ pub(super) async fn create_conversation_items(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to check item existence: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to check item existence: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -593,7 +619,7 @@ pub(super) async fn create_conversation_items(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(json!({"error": format!("Invalid item: {}", e)})),
|
||||
Json(json!({ "error": format!("Invalid item: {}", e) })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -610,7 +636,7 @@ pub(super) async fn create_conversation_items(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to create item: {}", e)})),
|
||||
Json(json!({ "error": format!("Failed to create item: {}", e) })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -678,7 +704,9 @@ pub(super) async fn get_conversation_item(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to get conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -693,7 +721,9 @@ pub(super) async fn get_conversation_item(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to check item link: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to check item link: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -721,7 +751,7 @@ pub(super) async fn get_conversation_item(
|
||||
.into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get item: {}", e)})),
|
||||
Json(json!({ "error": format!("Failed to get item: {}", e) })),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
@@ -753,7 +783,9 @@ pub(super) async fn delete_conversation_item(
|
||||
Err(e) => {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get conversation: {}", e)})),
|
||||
Json(json!({
|
||||
"error": format!("Failed to get conversation: {}", e)
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
@@ -773,7 +805,7 @@ pub(super) async fn delete_conversation_item(
|
||||
}
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to delete item: {}", e)})),
|
||||
Json(json!({ "error": format!("Failed to delete item: {}", e) })),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
|
||||
@@ -156,7 +156,7 @@ pub(super) fn patch_streaming_response_json(
|
||||
|
||||
// Attach conversation id for client response if present (final aggregated JSON)
|
||||
if let Some(conv_id) = original_body.conversation.clone() {
|
||||
obj.insert("conversation".to_string(), json!({"id": conv_id}));
|
||||
obj.insert("conversation".to_string(), json!({ "id": conv_id }));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -234,7 +234,7 @@ pub(super) fn rewrite_streaming_block(
|
||||
|
||||
// Attach conversation id into streaming event response content with ordering
|
||||
if let Some(conv_id) = original_body.conversation.clone() {
|
||||
response_obj.insert("conversation".to_string(), json!({"id": conv_id}));
|
||||
response_obj.insert("conversation".to_string(), json!({ "id": conv_id }));
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ use crate::{
|
||||
},
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
classify::ClassifyRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
@@ -828,7 +829,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
|
||||
.into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({"error": format!("Failed to get response: {}", e)})),
|
||||
Json(json!({ "error": format!("Failed to get response: {}", e) })),
|
||||
)
|
||||
.into_response(),
|
||||
}
|
||||
@@ -882,6 +883,15 @@ impl crate::routers::RouterTrait for OpenAIRouter {
|
||||
(StatusCode::NOT_IMPLEMENTED, "Rerank not supported").into_response()
|
||||
}
|
||||
|
||||
async fn route_classify(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &ClassifyRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED, "Classify not supported").into_response()
|
||||
}
|
||||
|
||||
async fn create_conversation(&self, _headers: Option<&HeaderMap>, body: &Value) -> Response {
|
||||
create_conversation(&self.conversation_storage, body.clone()).await
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ use crate::{
|
||||
core::{WorkerRegistry, WorkerType},
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
classify::ClassifyRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
@@ -329,10 +330,7 @@ impl RouterTrait for RouterManager {
|
||||
} else {
|
||||
(
|
||||
StatusCode::OK,
|
||||
serde_json::json!({
|
||||
"models": models
|
||||
})
|
||||
.to_string(),
|
||||
serde_json::json!({ "models": models }).to_string(),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
@@ -517,6 +515,25 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_classify(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ClassifyRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||
|
||||
if let Some(router) = router {
|
||||
router.route_classify(headers, body, model_id).await
|
||||
} else {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Model '{}' not found or no router available", body.model),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
"manager"
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ use crate::{
|
||||
policies::PolicyRegistry,
|
||||
protocols::{
|
||||
chat::ChatCompletionRequest,
|
||||
classify::ClassifyRequest,
|
||||
completion::CompletionRequest,
|
||||
embedding::EmbeddingRequest,
|
||||
generate::GenerateRequest,
|
||||
@@ -270,6 +271,17 @@ async fn v1_embeddings(
|
||||
.await
|
||||
}
|
||||
|
||||
async fn v1_classify(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<ClassifyRequest>,
|
||||
) -> Response {
|
||||
state
|
||||
.router
|
||||
.route_classify(Some(&headers), &body, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn v1_responses_get(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Path(response_id): Path<String>,
|
||||
@@ -534,13 +546,7 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
|
||||
})
|
||||
.collect();
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(json!({
|
||||
"workers": loads
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
(StatusCode::OK, Json(json!({ "workers": loads }))).into_response()
|
||||
}
|
||||
|
||||
async fn create_worker(
|
||||
@@ -707,6 +713,7 @@ pub fn build_app(
|
||||
.route("/v1/rerank", post(v1_rerank))
|
||||
.route("/v1/responses", post(v1_responses))
|
||||
.route("/v1/embeddings", post(v1_embeddings))
|
||||
.route("/v1/classify", post(v1_classify))
|
||||
.route("/v1/responses/{response_id}", get(v1_responses_get))
|
||||
.route(
|
||||
"/v1/responses/{response_id}/cancel",
|
||||
|
||||
@@ -1617,7 +1617,7 @@ async fn test_conversation_items_max_limit() {
|
||||
"content": [{"type": "input_text", "text": format!("Message {}", i)}]
|
||||
}));
|
||||
}
|
||||
let create_items = serde_json::json!({"items": items});
|
||||
let create_items = serde_json::json!({ "items": items });
|
||||
|
||||
let items_resp = router
|
||||
.create_conversation_items(None, conv_id, &create_items)
|
||||
|
||||
Reference in New Issue
Block a user