[router] Basic OAI Response api (#10346)
This commit is contained in:
@@ -289,6 +289,14 @@ impl RouterTrait for GrpcPDRouter {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_responses(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ResponsesRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
@@ -222,6 +222,14 @@ impl RouterTrait for GrpcRouter {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_responses(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ResponsesRequest,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
@@ -333,6 +333,18 @@ impl super::super::RouterTrait for OpenAIRouter {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_responses(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ResponsesRequest,
|
||||
) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Responses endpoint not implemented for OpenAI router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn flush_cache(&self) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
|
||||
@@ -9,8 +9,8 @@ use crate::core::{
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray,
|
||||
UserMessageContent,
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponsesRequest,
|
||||
StringOrArray, UserMessageContent,
|
||||
};
|
||||
use crate::routers::header_utils;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
@@ -1930,6 +1930,18 @@ impl RouterTrait for PDRouter {
|
||||
self.execute_dual_dispatch(headers, body, context).await
|
||||
}
|
||||
|
||||
async fn route_responses(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &ResponsesRequest,
|
||||
) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Responses endpoint not implemented for PD router",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
todo!()
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::core::{
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest,
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, ResponsesRequest,
|
||||
};
|
||||
use crate::routers::header_utils;
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
@@ -1210,6 +1210,15 @@ impl RouterTrait for Router {
|
||||
.await
|
||||
}
|
||||
|
||||
async fn route_responses(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ResponsesRequest,
|
||||
) -> Response {
|
||||
self.route_typed_request(headers, body, "/v1/responses")
|
||||
.await
|
||||
}
|
||||
|
||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||
todo!()
|
||||
}
|
||||
|
||||
@@ -9,7 +9,9 @@ use axum::{
|
||||
};
|
||||
use std::fmt::Debug;
|
||||
|
||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, ResponsesRequest,
|
||||
};
|
||||
|
||||
pub mod factory;
|
||||
pub mod grpc;
|
||||
@@ -78,6 +80,13 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
body: &CompletionRequest,
|
||||
) -> Response;
|
||||
|
||||
/// Route an OpenAI-compatible Responses API request (`POST /v1/responses`).
///
/// Implementations either forward `body` to a backend worker or answer
/// `501 NOT_IMPLEMENTED` when the router does not support this API yet
/// (see the gRPC/PD/OpenAI router stubs).
async fn route_responses(
    &self,
    headers: Option<&HeaderMap>,
    body: &ResponsesRequest,
) -> Response;
|
||||
|
||||
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||
|
||||
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||
|
||||
@@ -2,7 +2,9 @@ use crate::config::RouterConfig;
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::metrics::{self, PrometheusConfig};
|
||||
use crate::middleware::TokenBucket;
|
||||
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, ResponsesRequest,
|
||||
};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::{RouterFactory, RouterTrait};
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
@@ -150,6 +152,14 @@ async fn v1_completions(
|
||||
state.router.route_completion(Some(&headers), &body).await
|
||||
}
|
||||
|
||||
async fn v1_responses(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<ResponsesRequest>,
|
||||
) -> Response {
|
||||
state.router.route_responses(Some(&headers), &body).await
|
||||
}
|
||||
|
||||
// Worker management endpoints
|
||||
async fn add_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
@@ -227,6 +237,7 @@ pub fn build_app(
|
||||
.route("/generate", post(generate))
|
||||
.route("/v1/chat/completions", post(v1_chat_completions))
|
||||
.route("/v1/completions", post(v1_completions))
|
||||
.route("/v1/responses", post(v1_responses))
|
||||
.route_layer(axum::middleware::from_fn_with_state(
|
||||
app_state.clone(),
|
||||
crate::middleware::concurrency_limit_middleware,
|
||||
|
||||
@@ -991,6 +991,91 @@ mod router_policy_tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod responses_endpoint_tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_v1_responses_non_streaming() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18950,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"input": "Hello Responses API",
|
||||
"model": "mock-model",
|
||||
"stream": false
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/responses")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||
assert_eq!(body_json["object"], "response");
|
||||
assert_eq!(body_json["status"], "completed");
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_v1_responses_streaming() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 18951,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let app = ctx.create_app().await;
|
||||
|
||||
let payload = json!({
|
||||
"input": "Hello Responses API",
|
||||
"model": "mock-model",
|
||||
"stream": true
|
||||
});
|
||||
|
||||
let req = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/v1/responses")
|
||||
.header(CONTENT_TYPE, "application/json")
|
||||
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||
.unwrap();
|
||||
|
||||
let resp = app.clone().oneshot(req).await.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
|
||||
// Check that content-type indicates SSE
|
||||
let headers = resp.headers().clone();
|
||||
let ct = headers
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
assert!(ct.contains("text/event-stream"));
|
||||
|
||||
// We don't fully consume the stream in this test harness.
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod error_tests {
|
||||
use super::*;
|
||||
|
||||
@@ -81,6 +81,7 @@ impl MockWorker {
|
||||
.route("/generate", post(generate_handler))
|
||||
.route("/v1/chat/completions", post(chat_completions_handler))
|
||||
.route("/v1/completions", post(completions_handler))
|
||||
.route("/v1/responses", post(responses_handler))
|
||||
.route("/flush_cache", post(flush_cache_handler))
|
||||
.route("/v1/models", get(v1_models_handler))
|
||||
.with_state(config);
|
||||
@@ -548,6 +549,91 @@ async fn completions_handler(
|
||||
}
|
||||
}
|
||||
|
||||
async fn responses_handler(
|
||||
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
||||
Json(payload): Json<serde_json::Value>,
|
||||
) -> Response {
|
||||
let config = config.read().await;
|
||||
|
||||
if should_fail(&config).await {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": {
|
||||
"message": "Random failure for testing",
|
||||
"type": "internal_error",
|
||||
"code": "internal_error"
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
if config.response_delay_ms > 0 {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
|
||||
}
|
||||
|
||||
let is_stream = payload
|
||||
.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
|
||||
if is_stream {
|
||||
let request_id = format!("resp-{}", Uuid::new_v4());
|
||||
|
||||
let stream = stream::once(async move {
|
||||
let chunk = json!({
|
||||
"id": request_id,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"status": "in_progress",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "This is a mock responses streamed output."
|
||||
}]
|
||||
}]
|
||||
});
|
||||
Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
|
||||
})
|
||||
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
|
||||
|
||||
Sse::new(stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
} else {
|
||||
Json(json!({
|
||||
"id": format!("resp-{}", Uuid::new_v4()),
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": "This is a mock responses output."
|
||||
}]
|
||||
}],
|
||||
"status": "completed",
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush_cache_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
|
||||
let config = config.read().await;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user