[router] Add get and cancel method for response api (#10387)

This commit is contained in:
Keyang Ru
2025-09-12 16:19:38 -07:00
committed by GitHub
parent 2f173ea074
commit 366043db8e
9 changed files with 581 additions and 5 deletions

View File

@@ -994,6 +994,7 @@ mod router_policy_tests {
#[cfg(test)]
mod responses_endpoint_tests {
use super::*;
use reqwest::Client as HttpClient;
#[tokio::test]
async fn test_v1_responses_non_streaming() {
@@ -1074,6 +1075,207 @@ mod responses_endpoint_tests {
// We don't fully consume the stream in this test harness.
ctx.shutdown().await;
}
/// Creates a stored background response through the router and checks that
/// `GET /v1/responses/{id}` returns it as a `response` object.
#[tokio::test]
async fn test_v1_responses_get() {
    let ctx = TestContext::new(vec![MockWorkerConfig {
        port: 18952,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    }])
    .await;
    let app = ctx.create_app().await;

    // First create a response to obtain an id.
    // The mock worker only records an id when the request is a background request
    // that carries a `request_id`; a plain foreground request leaves nothing for
    // the GET below to find, so the response is created the same way the
    // multi-worker fan-out test creates it.
    let rid = "resp_18952"; // derived from this test's worker port to keep it unique
    let payload = json!({
        "input": "Hello Responses API",
        "model": "mock-model",
        "stream": false,
        "background": true,
        "store": true,
        "request_id": rid,
    });
    let req = Request::builder()
        .method("POST")
        .uri("/v1/responses")
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(serde_json::to_string(&payload).unwrap()))
        .unwrap();
    let resp = app.clone().oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    let resp_id = body_json["id"].as_str().unwrap().to_string();

    // Retrieve the response
    let req = Request::builder()
        .method("GET")
        .uri(format!("/v1/responses/{}", resp_id))
        .body(Body::empty())
        .unwrap();
    let resp = app.clone().oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let get_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(get_json["object"], "response");

    ctx.shutdown().await;
}
/// Creates a stored background response through the router and checks that
/// `POST /v1/responses/{id}/cancel` reports it as `cancelled`.
#[tokio::test]
async fn test_v1_responses_cancel() {
    let ctx = TestContext::new(vec![MockWorkerConfig {
        port: 18953,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    }])
    .await;
    let app = ctx.create_app().await;

    // First create a response to obtain an id.
    // Cancellation only applies to background responses, and the mock worker only
    // records an id for background requests that carry a `request_id`; a plain
    // foreground request would make the cancel call below hit an unknown id.
    let rid = "resp_18953"; // derived from this test's worker port to keep it unique
    let payload = json!({
        "input": "Hello Responses API",
        "model": "mock-model",
        "stream": false,
        "background": true,
        "store": true,
        "request_id": rid,
    });
    let req = Request::builder()
        .method("POST")
        .uri("/v1/responses")
        .header(CONTENT_TYPE, "application/json")
        .body(Body::from(serde_json::to_string(&payload).unwrap()))
        .unwrap();
    let resp = app.clone().oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    let resp_id = body_json["id"].as_str().unwrap().to_string();

    // Cancel the response
    let req = Request::builder()
        .method("POST")
        .uri(format!("/v1/responses/{}/cancel", resp_id))
        .body(Body::empty())
        .unwrap();
    let resp = app.clone().oneshot(req).await.unwrap();
    assert_eq!(resp.status(), StatusCode::OK);
    let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
        .await
        .unwrap();
    let cancel_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
    assert_eq!(cancel_json["status"], "cancelled");

    ctx.shutdown().await;
}
/// `DELETE /v1/responses/{id}` and `GET /v1/responses/{id}/input` are expected
/// to answer `501 Not Implemented`.
#[tokio::test]
async fn test_v1_responses_delete_and_list_not_implemented() {
    let worker = MockWorkerConfig {
        port: 18954,
        worker_type: WorkerType::Regular,
        health_status: HealthStatus::Healthy,
        response_delay_ms: 0,
        fail_rate: 0.0,
    };
    let ctx = TestContext::new(vec![worker]).await;
    let app = ctx.create_app().await;

    // Use an arbitrary id for delete/list
    let unknown_id = "resp-test-123";
    let unsupported = [
        ("DELETE", format!("/v1/responses/{unknown_id}")),
        ("GET", format!("/v1/responses/{unknown_id}/input")),
    ];

    // Issue the requests in order; each one must be rejected the same way.
    for (verb, target) in unsupported {
        let request = Request::builder()
            .method(verb)
            .uri(target)
            .body(Body::empty())
            .unwrap();
        let reply = app.clone().oneshot(request).await.unwrap();
        assert_eq!(reply.status(), StatusCode::NOT_IMPLEMENTED);
    }

    ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_responses_get_multi_worker_fanout() {
// Start two mock workers
let ctx = TestContext::new(vec![
MockWorkerConfig {
port: 18960,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
},
MockWorkerConfig {
port: 18961,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
},
])
.await;
let app = ctx.create_app().await;
// Create a background response with a known id
let rid = format!("resp_{}", 18960); // arbitrary unique id
let payload = json!({
"input": "Hello Responses API",
"model": "mock-model",
"background": true,
"store": true,
"request_id": rid,
});
let req = Request::builder()
.method("POST")
.uri("/v1/responses")
.header(CONTENT_TYPE, "application/json")
.body(Body::from(serde_json::to_string(&payload).unwrap()))
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// Using the router, GET should succeed by fanning out across workers
let req = Request::builder()
.method("GET")
.uri(format!("/v1/responses/{}", rid))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// Validate only one worker holds the metadata: direct calls
let client = HttpClient::new();
let mut ok_count = 0usize;
for url in ctx.router.get_worker_urls() {
let get_url = format!("{}/v1/responses/{}", url, rid);
let res = client.get(get_url).send().await.unwrap();
if res.status() == StatusCode::OK {
ok_count += 1;
}
}
assert_eq!(ok_count, 1, "exactly one worker should store the response");
ctx.shutdown().await;
}
}
#[cfg(test)]

View File

@@ -2,7 +2,7 @@
#![allow(dead_code)]
use axum::{
extract::{Json, State},
extract::{Json, Path, State},
http::StatusCode,
response::sse::{Event, KeepAlive},
response::{IntoResponse, Response, Sse},
@@ -11,8 +11,9 @@ use axum::{
};
use futures_util::stream::{self, StreamExt};
use serde_json::json;
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::sync::Arc;
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use uuid::Uuid;
@@ -83,6 +84,11 @@ impl MockWorker {
.route("/v1/completions", post(completions_handler))
.route("/v1/rerank", post(rerank_handler))
.route("/v1/responses", post(responses_handler))
.route("/v1/responses/{response_id}", get(responses_get_handler))
.route(
"/v1/responses/{response_id}/cancel",
post(responses_cancel_handler),
)
.route("/flush_cache", post(flush_cache_handler))
.route("/v1/models", get(v1_models_handler))
.with_state(config);
@@ -584,6 +590,21 @@ async fn responses_handler(
.unwrap()
.as_secs() as i64;
// Background storage simulation
let is_background = payload
.get("background")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let req_id = payload
.get("request_id")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
if is_background {
if let Some(id) = &req_id {
store_response_for_port(config.port, id);
}
}
if is_stream {
let request_id = format!("resp-{}", Uuid::new_v4());
@@ -610,6 +631,18 @@ async fn responses_handler(
Sse::new(stream)
.keep_alive(KeepAlive::default())
.into_response()
} else if is_background {
let rid = req_id.unwrap_or_else(|| format!("resp-{}", Uuid::new_v4()));
Json(json!({
"id": rid,
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"output": [],
"status": "queued",
"usage": null
}))
.into_response()
} else {
Json(json!({
"id": format!("resp-{}", Uuid::new_v4()),
@@ -688,6 +721,95 @@ async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>)
.into_response()
}
async fn responses_get_handler(
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
Path(response_id): Path<String>,
) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Random failure for testing" })),
)
.into_response();
}
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
// Only return 200 if this worker "stores" the response id
if response_exists_for_port(config.port, &response_id) {
Json(json!({
"id": response_id,
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"output": [],
"status": "completed",
"usage": {
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0
}
}))
.into_response()
} else {
StatusCode::NOT_FOUND.into_response()
}
}
async fn responses_cancel_handler(
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
Path(response_id): Path<String>,
) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": "Random failure for testing" })),
)
.into_response();
}
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64;
if response_exists_for_port(config.port, &response_id) {
Json(json!({
"id": response_id,
"object": "response",
"created_at": timestamp,
"model": "mock-model",
"output": [],
"status": "cancelled",
"usage": null
}))
.into_response()
} else {
StatusCode::NOT_FOUND.into_response()
}
}
// --- Simple in-memory response store per worker port (for tests) ---

/// Process-wide record of the response ids each mock worker "stores", keyed by
/// the worker's port.
///
/// The map is created lazily on first use. Nothing in this section removes
/// entries, so ids stay visible for the lifetime of the process.
static RESP_STORE: OnceLock<Mutex<HashMap<u16, HashSet<String>>>> = OnceLock::new();

/// Returns the global store, initialising it with an empty map on first access.
fn get_store() -> &'static Mutex<HashMap<u16, HashSet<String>>> {
    RESP_STORE.get_or_init(|| Mutex::new(HashMap::new()))
}

/// Locks the global store, recovering the guard if the mutex was poisoned.
///
/// The store is shared by every test in the process. Without recovery, a single
/// thread panicking while it holds the lock would make every later access panic
/// as well, hiding the original failure behind a cascade of unrelated ones.
fn lock_store() -> std::sync::MutexGuard<'static, HashMap<u16, HashSet<String>>> {
    get_store()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Records that the worker listening on `port` holds `response_id`.
///
/// Recording the same id twice is a no-op.
fn store_response_for_port(port: u16, response_id: &str) {
    lock_store()
        .entry(port)
        .or_default()
        .insert(response_id.to_string());
}

/// Returns `true` when `response_id` was previously recorded for `port`.
///
/// A port with no recorded ids yields `false`.
fn response_exists_for_port(port: u16, response_id: &str) -> bool {
    lock_store()
        .get(&port)
        .is_some_and(|ids| ids.contains(response_id))
}
// Minimal rerank handler returning mock results; router shapes final response
async fn rerank_handler(
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
Json(payload): Json<serde_json::Value>,