[router] Add get and cancel method for response api (#10387)
This commit is contained in:
@@ -301,6 +301,14 @@ impl RouterTrait for GrpcPDRouter {
|
|||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -234,6 +234,14 @@ impl RouterTrait for GrpcRouter {
|
|||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
|
}
|
||||||
|
|
||||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -351,6 +351,22 @@ impl super::super::RouterTrait for OpenAIRouter {
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Responses retrieve endpoint not implemented for OpenAI router",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Responses cancel endpoint not implemented for OpenAI router",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self) -> Response {
|
async fn flush_cache(&self) -> Response {
|
||||||
(
|
(
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
|||||||
@@ -1922,6 +1922,22 @@ impl RouterTrait for PDRouter {
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Responses retrieve endpoint not implemented for PD router",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Responses cancel endpoint not implemented for PD router",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ use axum::body::to_bytes;
|
|||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
extract::Request,
|
extract::Request,
|
||||||
http::{header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
http::{
|
||||||
|
header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode,
|
||||||
|
},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
@@ -600,6 +602,114 @@ impl Router {
|
|||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper: return base worker URL (strips DP suffix when enabled)
|
||||||
|
fn worker_base_url(&self, worker_url: &str) -> String {
|
||||||
|
if self.dp_aware {
|
||||||
|
if let Ok((prefix, _)) = Self::extract_dp_rank(worker_url) {
|
||||||
|
return prefix.to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
worker_url.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generic simple routing for GET/POST without JSON body
|
||||||
|
async fn route_simple_request(
|
||||||
|
&self,
|
||||||
|
headers: Option<&HeaderMap>,
|
||||||
|
endpoint: &str,
|
||||||
|
method: Method,
|
||||||
|
) -> Response {
|
||||||
|
// TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers.
|
||||||
|
// Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly.
|
||||||
|
let worker_urls = self.get_worker_urls();
|
||||||
|
if worker_urls.is_empty() {
|
||||||
|
return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut last_response: Option<Response> = None;
|
||||||
|
for worker_url in worker_urls {
|
||||||
|
let base = self.worker_base_url(&worker_url);
|
||||||
|
|
||||||
|
let url = format!("{}/{}", base, endpoint);
|
||||||
|
let mut request_builder = match method {
|
||||||
|
Method::GET => self.client.get(url),
|
||||||
|
Method::POST => self.client.post(url),
|
||||||
|
_ => {
|
||||||
|
return (
|
||||||
|
StatusCode::METHOD_NOT_ALLOWED,
|
||||||
|
"Unsupported method for simple routing",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(hdrs) = headers {
|
||||||
|
for (name, value) in hdrs {
|
||||||
|
let name_lc = name.as_str().to_lowercase();
|
||||||
|
if name_lc != "content-type" && name_lc != "content-length" {
|
||||||
|
request_builder = request_builder.header(name, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match request_builder.send().await {
|
||||||
|
Ok(res) => {
|
||||||
|
let status = StatusCode::from_u16(res.status().as_u16())
|
||||||
|
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
let response_headers = header_utils::preserve_response_headers(res.headers());
|
||||||
|
match res.bytes().await {
|
||||||
|
Ok(body) => {
|
||||||
|
let mut response = Response::new(axum::body::Body::from(body));
|
||||||
|
*response.status_mut() = status;
|
||||||
|
*response.headers_mut() = response_headers;
|
||||||
|
if status.is_success() {
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
last_response = Some(response);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
last_response = Some(
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Failed to read response: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
last_response = Some(
|
||||||
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
format!("Request failed: {}", e),
|
||||||
|
)
|
||||||
|
.into_response(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
last_response
|
||||||
|
.unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route a GET request with provided headers to a specific endpoint
|
||||||
|
async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response {
|
||||||
|
self.route_simple_request(headers, endpoint, Method::GET)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route a POST request with empty body to a specific endpoint
|
||||||
|
async fn route_post_empty_request(
|
||||||
|
&self,
|
||||||
|
headers: Option<&HeaderMap>,
|
||||||
|
endpoint: &str,
|
||||||
|
) -> Response {
|
||||||
|
self.route_simple_request(headers, endpoint, Method::POST)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
// TODO (rui): Better accommodate to the Worker abstraction
|
// TODO (rui): Better accommodate to the Worker abstraction
|
||||||
fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
|
fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
|
||||||
let parts: Vec<&str> = worker_url.split('@').collect();
|
let parts: Vec<&str> = worker_url.split('@').collect();
|
||||||
@@ -1310,6 +1420,16 @@ impl RouterTrait for Router {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fetches a response by id by issuing `GET v1/responses/{response_id}`
/// against the workers.
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
    let path = format!("v1/responses/{}", response_id);
    self.route_get_request(headers, path.as_str()).await
}
|
||||||
|
|
||||||
|
/// Cancels a response by id by issuing a body-less
/// `POST v1/responses/{response_id}/cancel` against the workers.
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
    let path = format!("v1/responses/{}/cancel", response_id);
    self.route_post_empty_request(headers, path.as_str()).await
}
|
||||||
|
|
||||||
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -95,6 +95,34 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
|||||||
model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
) -> Response;
|
) -> Response;
|
||||||
|
|
||||||
|
/// Retrieve a stored/background response by id
|
||||||
|
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response;
|
||||||
|
|
||||||
|
/// Cancel a background response by id
|
||||||
|
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response;
|
||||||
|
|
||||||
|
/// Delete a response by id
|
||||||
|
async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Responses delete endpoint not implemented",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List input items of a response by id
|
||||||
|
async fn list_response_input_items(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_response_id: &str,
|
||||||
|
) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"Responses list input items endpoint not implemented",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||||
|
|
||||||
async fn route_rerank(
|
async fn route_rerank(
|
||||||
|
|||||||
@@ -16,10 +16,10 @@ use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
|||||||
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
|
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
|
||||||
use crate::tool_parser::ParserRegistry;
|
use crate::tool_parser::ParserRegistry;
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Query, Request, State},
|
extract::{Path, Query, Request, State},
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
routing::{get, post},
|
routing::{delete, get, post},
|
||||||
Json, Router,
|
Json, Router,
|
||||||
};
|
};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
@@ -208,6 +208,52 @@ async fn v1_responses(
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn v1_responses_get(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Path(response_id): Path<String>,
|
||||||
|
headers: http::HeaderMap,
|
||||||
|
) -> Response {
|
||||||
|
state
|
||||||
|
.router
|
||||||
|
.get_response(Some(&headers), &response_id)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn v1_responses_cancel(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Path(response_id): Path<String>,
|
||||||
|
headers: http::HeaderMap,
|
||||||
|
) -> Response {
|
||||||
|
state
|
||||||
|
.router
|
||||||
|
.cancel_response(Some(&headers), &response_id)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn v1_responses_delete(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Path(response_id): Path<String>,
|
||||||
|
headers: http::HeaderMap,
|
||||||
|
) -> Response {
|
||||||
|
// Python server does not support this yet
|
||||||
|
state
|
||||||
|
.router
|
||||||
|
.delete_response(Some(&headers), &response_id)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn v1_responses_list_input_items(
|
||||||
|
State(state): State<Arc<AppState>>,
|
||||||
|
Path(response_id): Path<String>,
|
||||||
|
headers: http::HeaderMap,
|
||||||
|
) -> Response {
|
||||||
|
// Python server does not support this yet
|
||||||
|
state
|
||||||
|
.router
|
||||||
|
.list_response_input_items(Some(&headers), &response_id)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
// Worker management endpoints
|
// Worker management endpoints
|
||||||
async fn add_worker(
|
async fn add_worker(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
@@ -419,6 +465,16 @@ pub fn build_app(
|
|||||||
.route("/rerank", post(rerank))
|
.route("/rerank", post(rerank))
|
||||||
.route("/v1/rerank", post(v1_rerank))
|
.route("/v1/rerank", post(v1_rerank))
|
||||||
.route("/v1/responses", post(v1_responses))
|
.route("/v1/responses", post(v1_responses))
|
||||||
|
.route("/v1/responses/{response_id}", get(v1_responses_get))
|
||||||
|
.route(
|
||||||
|
"/v1/responses/{response_id}/cancel",
|
||||||
|
post(v1_responses_cancel),
|
||||||
|
)
|
||||||
|
.route("/v1/responses/{response_id}", delete(v1_responses_delete))
|
||||||
|
.route(
|
||||||
|
"/v1/responses/{response_id}/input",
|
||||||
|
get(v1_responses_list_input_items),
|
||||||
|
)
|
||||||
.route_layer(axum::middleware::from_fn_with_state(
|
.route_layer(axum::middleware::from_fn_with_state(
|
||||||
app_state.clone(),
|
app_state.clone(),
|
||||||
crate::middleware::concurrency_limit_middleware,
|
crate::middleware::concurrency_limit_middleware,
|
||||||
|
|||||||
@@ -994,6 +994,7 @@ mod router_policy_tests {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod responses_endpoint_tests {
|
mod responses_endpoint_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use reqwest::Client as HttpClient;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_v1_responses_non_streaming() {
|
async fn test_v1_responses_non_streaming() {
|
||||||
@@ -1074,6 +1075,207 @@ mod responses_endpoint_tests {
|
|||||||
// We don't fully consume the stream in this test harness.
|
// We don't fully consume the stream in this test harness.
|
||||||
ctx.shutdown().await;
|
ctx.shutdown().await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_v1_responses_get() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 18952,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
|
// First create a response to obtain an id
|
||||||
|
let payload = json!({
|
||||||
|
"input": "Hello Responses API",
|
||||||
|
"model": "mock-model",
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("POST")
|
||||||
|
.uri("/v1/responses")
|
||||||
|
.header(CONTENT_TYPE, "application/json")
|
||||||
|
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||||
|
.unwrap();
|
||||||
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
|
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
let resp_id = body_json["id"].as_str().unwrap().to_string();
|
||||||
|
|
||||||
|
// Retrieve the response
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("GET")
|
||||||
|
.uri(format!("/v1/responses/{}", resp_id))
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
|
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let get_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
assert_eq!(get_json["object"], "response");
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_v1_responses_cancel() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 18953,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
|
// First create a response to obtain an id
|
||||||
|
let payload = json!({
|
||||||
|
"input": "Hello Responses API",
|
||||||
|
"model": "mock-model",
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("POST")
|
||||||
|
.uri("/v1/responses")
|
||||||
|
.header(CONTENT_TYPE, "application/json")
|
||||||
|
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||||
|
.unwrap();
|
||||||
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
|
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
let resp_id = body_json["id"].as_str().unwrap().to_string();
|
||||||
|
|
||||||
|
// Cancel the response
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("POST")
|
||||||
|
.uri(format!("/v1/responses/{}/cancel", resp_id))
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
|
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let cancel_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
||||||
|
assert_eq!(cancel_json["status"], "cancelled");
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_v1_responses_delete_and_list_not_implemented() {
|
||||||
|
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||||
|
port: 18954,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
}])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
|
// Use an arbitrary id for delete/list
|
||||||
|
let resp_id = "resp-test-123";
|
||||||
|
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("DELETE")
|
||||||
|
.uri(format!("/v1/responses/{}", resp_id))
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);
|
||||||
|
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("GET")
|
||||||
|
.uri(format!("/v1/responses/{}/input", resp_id))
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED);
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_v1_responses_get_multi_worker_fanout() {
|
||||||
|
// Start two mock workers
|
||||||
|
let ctx = TestContext::new(vec![
|
||||||
|
MockWorkerConfig {
|
||||||
|
port: 18960,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
},
|
||||||
|
MockWorkerConfig {
|
||||||
|
port: 18961,
|
||||||
|
worker_type: WorkerType::Regular,
|
||||||
|
health_status: HealthStatus::Healthy,
|
||||||
|
response_delay_ms: 0,
|
||||||
|
fail_rate: 0.0,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let app = ctx.create_app().await;
|
||||||
|
|
||||||
|
// Create a background response with a known id
|
||||||
|
let rid = format!("resp_{}", 18960); // arbitrary unique id
|
||||||
|
let payload = json!({
|
||||||
|
"input": "Hello Responses API",
|
||||||
|
"model": "mock-model",
|
||||||
|
"background": true,
|
||||||
|
"store": true,
|
||||||
|
"request_id": rid,
|
||||||
|
});
|
||||||
|
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("POST")
|
||||||
|
.uri("/v1/responses")
|
||||||
|
.header(CONTENT_TYPE, "application/json")
|
||||||
|
.body(Body::from(serde_json::to_string(&payload).unwrap()))
|
||||||
|
.unwrap();
|
||||||
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
// Using the router, GET should succeed by fanning out across workers
|
||||||
|
let req = Request::builder()
|
||||||
|
.method("GET")
|
||||||
|
.uri(format!("/v1/responses/{}", rid))
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
|
assert_eq!(resp.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
// Validate only one worker holds the metadata: direct calls
|
||||||
|
let client = HttpClient::new();
|
||||||
|
let mut ok_count = 0usize;
|
||||||
|
for url in ctx.router.get_worker_urls() {
|
||||||
|
let get_url = format!("{}/v1/responses/{}", url, rid);
|
||||||
|
let res = client.get(get_url).send().await.unwrap();
|
||||||
|
if res.status() == StatusCode::OK {
|
||||||
|
ok_count += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert_eq!(ok_count, 1, "exactly one worker should store the response");
|
||||||
|
|
||||||
|
ctx.shutdown().await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Json, State},
|
extract::{Json, Path, State},
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
response::sse::{Event, KeepAlive},
|
response::sse::{Event, KeepAlive},
|
||||||
response::{IntoResponse, Response, Sse},
|
response::{IntoResponse, Response, Sse},
|
||||||
@@ -11,8 +11,9 @@ use axum::{
|
|||||||
};
|
};
|
||||||
use futures_util::stream::{self, StreamExt};
|
use futures_util::stream::{self, StreamExt};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::convert::Infallible;
|
use std::convert::Infallible;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex, OnceLock};
|
||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -83,6 +84,11 @@ impl MockWorker {
|
|||||||
.route("/v1/completions", post(completions_handler))
|
.route("/v1/completions", post(completions_handler))
|
||||||
.route("/v1/rerank", post(rerank_handler))
|
.route("/v1/rerank", post(rerank_handler))
|
||||||
.route("/v1/responses", post(responses_handler))
|
.route("/v1/responses", post(responses_handler))
|
||||||
|
.route("/v1/responses/{response_id}", get(responses_get_handler))
|
||||||
|
.route(
|
||||||
|
"/v1/responses/{response_id}/cancel",
|
||||||
|
post(responses_cancel_handler),
|
||||||
|
)
|
||||||
.route("/flush_cache", post(flush_cache_handler))
|
.route("/flush_cache", post(flush_cache_handler))
|
||||||
.route("/v1/models", get(v1_models_handler))
|
.route("/v1/models", get(v1_models_handler))
|
||||||
.with_state(config);
|
.with_state(config);
|
||||||
@@ -584,6 +590,21 @@ async fn responses_handler(
|
|||||||
.unwrap()
|
.unwrap()
|
||||||
.as_secs() as i64;
|
.as_secs() as i64;
|
||||||
|
|
||||||
|
// Background storage simulation
|
||||||
|
let is_background = payload
|
||||||
|
.get("background")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false);
|
||||||
|
let req_id = payload
|
||||||
|
.get("request_id")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
if is_background {
|
||||||
|
if let Some(id) = &req_id {
|
||||||
|
store_response_for_port(config.port, id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if is_stream {
|
if is_stream {
|
||||||
let request_id = format!("resp-{}", Uuid::new_v4());
|
let request_id = format!("resp-{}", Uuid::new_v4());
|
||||||
|
|
||||||
@@ -610,6 +631,18 @@ async fn responses_handler(
|
|||||||
Sse::new(stream)
|
Sse::new(stream)
|
||||||
.keep_alive(KeepAlive::default())
|
.keep_alive(KeepAlive::default())
|
||||||
.into_response()
|
.into_response()
|
||||||
|
} else if is_background {
|
||||||
|
let rid = req_id.unwrap_or_else(|| format!("resp-{}", Uuid::new_v4()));
|
||||||
|
Json(json!({
|
||||||
|
"id": rid,
|
||||||
|
"object": "response",
|
||||||
|
"created_at": timestamp,
|
||||||
|
"model": "mock-model",
|
||||||
|
"output": [],
|
||||||
|
"status": "queued",
|
||||||
|
"usage": null
|
||||||
|
}))
|
||||||
|
.into_response()
|
||||||
} else {
|
} else {
|
||||||
Json(json!({
|
Json(json!({
|
||||||
"id": format!("resp-{}", Uuid::new_v4()),
|
"id": format!("resp-{}", Uuid::new_v4()),
|
||||||
@@ -688,6 +721,95 @@ async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>)
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn responses_get_handler(
|
||||||
|
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
||||||
|
Path(response_id): Path<String>,
|
||||||
|
) -> Response {
|
||||||
|
let config = config.read().await;
|
||||||
|
if should_fail(&config).await {
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({ "error": "Random failure for testing" })),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
let timestamp = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs() as i64;
|
||||||
|
// Only return 200 if this worker "stores" the response id
|
||||||
|
if response_exists_for_port(config.port, &response_id) {
|
||||||
|
Json(json!({
|
||||||
|
"id": response_id,
|
||||||
|
"object": "response",
|
||||||
|
"created_at": timestamp,
|
||||||
|
"model": "mock-model",
|
||||||
|
"output": [],
|
||||||
|
"status": "completed",
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": 0,
|
||||||
|
"output_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.into_response()
|
||||||
|
} else {
|
||||||
|
StatusCode::NOT_FOUND.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn responses_cancel_handler(
|
||||||
|
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
||||||
|
Path(response_id): Path<String>,
|
||||||
|
) -> Response {
|
||||||
|
let config = config.read().await;
|
||||||
|
if should_fail(&config).await {
|
||||||
|
return (
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({ "error": "Random failure for testing" })),
|
||||||
|
)
|
||||||
|
.into_response();
|
||||||
|
}
|
||||||
|
let timestamp = SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_secs() as i64;
|
||||||
|
if response_exists_for_port(config.port, &response_id) {
|
||||||
|
Json(json!({
|
||||||
|
"id": response_id,
|
||||||
|
"object": "response",
|
||||||
|
"created_at": timestamp,
|
||||||
|
"model": "mock-model",
|
||||||
|
"output": [],
|
||||||
|
"status": "cancelled",
|
||||||
|
"usage": null
|
||||||
|
}))
|
||||||
|
.into_response()
|
||||||
|
} else {
|
||||||
|
StatusCode::NOT_FOUND.into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Simple in-memory response store per worker port (for tests) ---

/// Process-wide map from worker port to the response ids recorded for it.
/// Created empty on first access.
static RESP_STORE: OnceLock<Mutex<HashMap<u16, HashSet<String>>>> = OnceLock::new();

/// Returns the shared store, initialising it on first use.
fn get_store() -> &'static Mutex<HashMap<u16, HashSet<String>>> {
    RESP_STORE.get_or_init(Mutex::default)
}

/// Records `response_id` under `port`. Recording the same id twice is a no-op.
fn store_response_for_port(port: u16, response_id: &str) {
    let mut by_port = get_store().lock().unwrap();
    let ids = by_port.entry(port).or_default();
    ids.insert(response_id.to_string());
}

/// Reports whether `response_id` has been recorded under `port`.
fn response_exists_for_port(port: u16, response_id: &str) -> bool {
    let by_port = get_store().lock().unwrap();
    match by_port.get(&port) {
        Some(ids) => ids.contains(response_id),
        None => false,
    }
}
|
||||||
|
|
||||||
|
// Minimal rerank handler returning mock results; router shapes final response
|
||||||
async fn rerank_handler(
|
async fn rerank_handler(
|
||||||
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
||||||
Json(payload): Json<serde_json::Value>,
|
Json(payload): Json<serde_json::Value>,
|
||||||
|
|||||||
Reference in New Issue
Block a user