[router] Add get and cancel method for response api (#10387)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use axum::{
|
||||
extract::{Json, State},
|
||||
extract::{Json, Path, State},
|
||||
http::StatusCode,
|
||||
response::sse::{Event, KeepAlive},
|
||||
response::{IntoResponse, Response, Sse},
|
||||
@@ -11,8 +11,9 @@ use axum::{
|
||||
};
|
||||
use futures_util::stream::{self, StreamExt};
|
||||
use serde_json::json;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
@@ -83,6 +84,11 @@ impl MockWorker {
|
||||
.route("/v1/completions", post(completions_handler))
|
||||
.route("/v1/rerank", post(rerank_handler))
|
||||
.route("/v1/responses", post(responses_handler))
|
||||
.route("/v1/responses/{response_id}", get(responses_get_handler))
|
||||
.route(
|
||||
"/v1/responses/{response_id}/cancel",
|
||||
post(responses_cancel_handler),
|
||||
)
|
||||
.route("/flush_cache", post(flush_cache_handler))
|
||||
.route("/v1/models", get(v1_models_handler))
|
||||
.with_state(config);
|
||||
@@ -584,6 +590,21 @@ async fn responses_handler(
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
|
||||
// Background storage simulation
|
||||
let is_background = payload
|
||||
.get("background")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
let req_id = payload
|
||||
.get("request_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
if is_background {
|
||||
if let Some(id) = &req_id {
|
||||
store_response_for_port(config.port, id);
|
||||
}
|
||||
}
|
||||
|
||||
if is_stream {
|
||||
let request_id = format!("resp-{}", Uuid::new_v4());
|
||||
|
||||
@@ -610,6 +631,18 @@ async fn responses_handler(
|
||||
Sse::new(stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response()
|
||||
} else if is_background {
|
||||
let rid = req_id.unwrap_or_else(|| format!("resp-{}", Uuid::new_v4()));
|
||||
Json(json!({
|
||||
"id": rid,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [],
|
||||
"status": "queued",
|
||||
"usage": null
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
Json(json!({
|
||||
"id": format!("resp-{}", Uuid::new_v4()),
|
||||
@@ -688,6 +721,95 @@ async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn responses_get_handler(
|
||||
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
||||
Path(response_id): Path<String>,
|
||||
) -> Response {
|
||||
let config = config.read().await;
|
||||
if should_fail(&config).await {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": "Random failure for testing" })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
// Only return 200 if this worker "stores" the response id
|
||||
if response_exists_for_port(config.port, &response_id) {
|
||||
Json(json!({
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [],
|
||||
"status": "completed",
|
||||
"usage": {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn responses_cancel_handler(
|
||||
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
||||
Path(response_id): Path<String>,
|
||||
) -> Response {
|
||||
let config = config.read().await;
|
||||
if should_fail(&config).await {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": "Random failure for testing" })),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs() as i64;
|
||||
if response_exists_for_port(config.port, &response_id) {
|
||||
Json(json!({
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": timestamp,
|
||||
"model": "mock-model",
|
||||
"output": [],
|
||||
"status": "cancelled",
|
||||
"usage": null
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
StatusCode::NOT_FOUND.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// --- Simple in-memory response store per worker port (for tests) ---
|
||||
static RESP_STORE: OnceLock<Mutex<HashMap<u16, HashSet<String>>>> = OnceLock::new();
|
||||
|
||||
fn get_store() -> &'static Mutex<HashMap<u16, HashSet<String>>> {
|
||||
RESP_STORE.get_or_init(|| Mutex::new(HashMap::new()))
|
||||
}
|
||||
|
||||
fn store_response_for_port(port: u16, response_id: &str) {
|
||||
let mut map = get_store().lock().unwrap();
|
||||
map.entry(port).or_default().insert(response_id.to_string());
|
||||
}
|
||||
|
||||
fn response_exists_for_port(port: u16, response_id: &str) -> bool {
|
||||
let map = get_store().lock().unwrap();
|
||||
map.get(&port)
|
||||
.map(|set| set.contains(response_id))
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
// Minimal rerank handler returning mock results; router shapes final response
|
||||
async fn rerank_handler(
|
||||
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
|
||||
Json(payload): Json<serde_json::Value>,
|
||||
|
||||
Reference in New Issue
Block a user