diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index 3efb9ca87..89254c831 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -301,6 +301,14 @@ impl RouterTrait for GrpcPDRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } + async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index cb4bab412..f631c1358 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -234,6 +234,14 @@ impl RouterTrait for GrpcRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } + async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + + async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index e75cb794a..bfe5a1d7b 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -351,6 +351,22 @@ impl super::super::RouterTrait for OpenAIRouter { .into_response() } + async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Responses retrieve endpoint not implemented for OpenAI router", + ) + .into_response() + } + + async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Responses cancel endpoint not implemented for OpenAI router", + ) + .into_response() + } + async fn flush_cache(&self) -> Response { ( StatusCode::NOT_IMPLEMENTED, diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 4f31cc225..7b23c298f 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -1922,6 +1922,22 @@ impl RouterTrait for PDRouter { .into_response() } + async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Responses retrieve endpoint not implemented for PD router", + ) + .into_response() + } + + async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Responses cancel endpoint not implemented for PD router", + ) + .into_response() + } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { todo!() } diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index 8b928ea37..6c1eb2554 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -15,7 +15,9 @@ use axum::body::to_bytes; use axum::{ body::Body, extract::Request, - http::{header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, + http::{ + header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode, + }, response::{IntoResponse, Response}, Json, }; @@ -600,6 +602,114 @@ impl Router { response } + // Helper: return base worker URL (strips DP suffix when enabled) + fn worker_base_url(&self, worker_url: &str) -> String { + if self.dp_aware { + if let Ok((prefix, _)) = Self::extract_dp_rank(worker_url) { + return prefix.to_string(); + } + } + worker_url.to_string() + } + + // Generic simple routing for GET/POST without JSON body + async fn route_simple_request( + &self, + headers: Option<&HeaderMap>, + endpoint: &str, + method: Method, + ) -> Response { + // TODO: currently the sglang worker is using in-memory state management, so this implementation has to fan out to all workers. + // Eventually, we need to have router to manage the chat history with a proper database, will update this implementation accordingly. + let worker_urls = self.get_worker_urls(); + if worker_urls.is_empty() { + return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response(); + } + + let mut last_response: Option = None; + for worker_url in worker_urls { + let base = self.worker_base_url(&worker_url); + + let url = format!("{}/{}", base, endpoint); + let mut request_builder = match method { + Method::GET => self.client.get(url), + Method::POST => self.client.post(url), + _ => { + return ( + StatusCode::METHOD_NOT_ALLOWED, + "Unsupported method for simple routing", + ) + .into_response() + } + }; + + if let Some(hdrs) = headers { + for (name, value) in hdrs { + let name_lc = name.as_str().to_lowercase(); + if name_lc != "content-type" && name_lc != "content-length" { + request_builder = request_builder.header(name, value); + } + } + } + + match request_builder.send().await { + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + let response_headers = header_utils::preserve_response_headers(res.headers()); + match res.bytes().await { + Ok(body) => { + let mut response = Response::new(axum::body::Body::from(body)); + *response.status_mut() = status; + *response.headers_mut() = response_headers; + if status.is_success() { + return response; + } + last_response = Some(response); + } + Err(e) => { + last_response = Some( + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response: {}", e), + ) + .into_response(), + ); + } + } + } + Err(e) => { + last_response = Some( + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Request failed: {}", e), + ) + .into_response(), + ); + } + } + } + + last_response + .unwrap_or_else(|| (StatusCode::BAD_GATEWAY, "No worker response").into_response()) + } + + // Route a GET request with provided headers to a specific endpoint + async fn route_get_request(&self, headers: Option<&HeaderMap>, endpoint: &str) -> Response { + self.route_simple_request(headers, endpoint, Method::GET) + .await + } + + // Route a POST request with empty body to a specific endpoint + async fn route_post_empty_request( + &self, + headers: Option<&HeaderMap>, + endpoint: &str, + ) -> Response { + self.route_simple_request(headers, endpoint, Method::POST) + .await + } + // TODO (rui): Better accommodate to the Worker abstraction fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> { let parts: Vec<&str> = worker_url.split('@').collect(); @@ -1310,6 +1420,16 @@ impl RouterTrait for Router { .await } + async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { + let endpoint = format!("v1/responses/{}", response_id); + self.route_get_request(headers, &endpoint).await + } + + async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response { + let endpoint = format!("v1/responses/{}/cancel", response_id); + self.route_post_empty_request(headers, &endpoint).await + } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { todo!() } diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index fba121002..d47951a40 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -95,6 +95,34 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { model_id: Option<&str>, ) -> Response; + /// Retrieve a stored/background response by id + async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response; + + /// Cancel a background response by id + async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response; + + /// Delete a response by id + async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Responses delete endpoint not implemented", + ) + .into_response() + } + + /// List input items of a response by id + async fn list_response_input_items( + &self, + _headers: Option<&HeaderMap>, + _response_id: &str, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Responses list input items endpoint not implemented", + ) + .into_response() + } + async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response; async fn route_rerank( diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index acaf6a19d..b6eeafafe 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -16,10 +16,10 @@ use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer}; use crate::tool_parser::ParserRegistry; use axum::{ - extract::{Query, Request, State}, + extract::{Path, Query, Request, State}, http::StatusCode, response::{IntoResponse, Response}, - routing::{get, post}, + routing::{delete, get, post}, Json, Router, }; use reqwest::Client; @@ -208,6 +208,52 @@ async fn v1_responses( .await } +async fn v1_responses_get( + State(state): State>, + Path(response_id): Path, + headers: http::HeaderMap, +) -> Response { + state + .router + .get_response(Some(&headers), &response_id) + .await +} + +async fn v1_responses_cancel( + State(state): State>, + Path(response_id): Path, + headers: http::HeaderMap, +) -> Response { + state + .router + .cancel_response(Some(&headers), &response_id) + .await +} + +async fn v1_responses_delete( + State(state): State>, + Path(response_id): Path, + headers: http::HeaderMap, +) -> Response { + // Python server does not support this yet + state + .router + .delete_response(Some(&headers), &response_id) + .await +} + +async fn v1_responses_list_input_items( + State(state): State>, + Path(response_id): Path, + headers: http::HeaderMap, +) -> Response { + // Python server does not support this yet + state + .router + .list_response_input_items(Some(&headers), &response_id) + .await +} + // Worker management endpoints async fn add_worker( State(state): State>, @@ -419,6 +465,16 @@ pub fn build_app( .route("/rerank", post(rerank)) .route("/v1/rerank", post(v1_rerank)) .route("/v1/responses", post(v1_responses)) + .route("/v1/responses/{response_id}", get(v1_responses_get)) + .route( + "/v1/responses/{response_id}/cancel", + post(v1_responses_cancel), + ) + .route("/v1/responses/{response_id}", delete(v1_responses_delete)) + .route( + "/v1/responses/{response_id}/input", + get(v1_responses_list_input_items), + ) .route_layer(axum::middleware::from_fn_with_state( app_state.clone(), crate::middleware::concurrency_limit_middleware, diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index 39911a20d..b7a346338 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -994,6 +994,7 @@ mod router_policy_tests { #[cfg(test)] mod responses_endpoint_tests { use super::*; + use reqwest::Client as HttpClient; #[tokio::test] async fn test_v1_responses_non_streaming() { @@ -1074,6 +1075,207 @@ mod responses_endpoint_tests { // We don't fully consume the stream in this test harness. ctx.shutdown().await; } + + #[tokio::test] + async fn test_v1_responses_get() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18952, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // First create a response to obtain an id + let payload = json!({ + "input": "Hello Responses API", + "model": "mock-model", + "stream": false + }); + let req = Request::builder() + .method("POST") + .uri("/v1/responses") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + let resp_id = body_json["id"].as_str().unwrap().to_string(); + + // Retrieve the response + let req = Request::builder() + .method("GET") + .uri(format!("/v1/responses/{}", resp_id)) + .body(Body::empty()) + .unwrap(); + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let get_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(get_json["object"], "response"); + + ctx.shutdown().await; + } + + #[tokio::test] + async fn test_v1_responses_cancel() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18953, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // First create a response to obtain an id + let payload = json!({ + "input": "Hello Responses API", + "model": "mock-model", + "stream": false + }); + let req = Request::builder() + .method("POST") + .uri("/v1/responses") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + let resp_id = body_json["id"].as_str().unwrap().to_string(); + + // Cancel the response + let req = Request::builder() + .method("POST") + .uri(format!("/v1/responses/{}/cancel", resp_id)) + .body(Body::empty()) + .unwrap(); + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let cancel_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(cancel_json["status"], "cancelled"); + + ctx.shutdown().await; + } + + #[tokio::test] + async fn test_v1_responses_delete_and_list_not_implemented() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18954, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + // Use an arbitrary id for delete/list + let resp_id = "resp-test-123"; + + let req = Request::builder() + .method("DELETE") + .uri(format!("/v1/responses/{}", resp_id)) + .body(Body::empty()) + .unwrap(); + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED); + + let req = Request::builder() + .method("GET") + .uri(format!("/v1/responses/{}/input", resp_id)) + .body(Body::empty()) + .unwrap(); + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_IMPLEMENTED); + + ctx.shutdown().await; + } + + #[tokio::test] + async fn test_v1_responses_get_multi_worker_fanout() { + // Start two mock workers + let ctx = TestContext::new(vec![ + MockWorkerConfig { + port: 18960, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + MockWorkerConfig { + port: 18961, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }, + ]) + .await; + + let app = ctx.create_app().await; + + // Create a background response with a known id + let rid = format!("resp_{}", 18960); // arbitrary unique id + let payload = json!({ + "input": "Hello Responses API", + "model": "mock-model", + "background": true, + "store": true, + "request_id": rid, + }); + + let req = Request::builder() + .method("POST") + .uri("/v1/responses") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + // Using the router, GET should succeed by fanning out across workers + let req = Request::builder() + .method("GET") + .uri(format!("/v1/responses/{}", rid)) + .body(Body::empty()) + .unwrap(); + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + // Validate only one worker holds the metadata: direct calls + let client = HttpClient::new(); + let mut ok_count = 0usize; + for url in ctx.router.get_worker_urls() { + let get_url = format!("{}/v1/responses/{}", url, rid); + let res = client.get(get_url).send().await.unwrap(); + if res.status() == StatusCode::OK { + ok_count += 1; + } + } + assert_eq!(ok_count, 1, "exactly one worker should store the response"); + + ctx.shutdown().await; + } } #[cfg(test)] diff --git a/sgl-router/tests/common/mock_worker.rs b/sgl-router/tests/common/mock_worker.rs index daf2a4889..f46e6b854 100755 --- a/sgl-router/tests/common/mock_worker.rs +++ b/sgl-router/tests/common/mock_worker.rs @@ -2,7 +2,7 @@ #![allow(dead_code)] use axum::{ - extract::{Json, State}, + extract::{Json, Path, State}, http::StatusCode, response::sse::{Event, KeepAlive}, response::{IntoResponse, Response, Sse}, @@ -11,8 +11,9 @@ use axum::{ }; use futures_util::stream::{self, StreamExt}; use serde_json::json; +use std::collections::{HashMap, HashSet}; use std::convert::Infallible; -use std::sync::Arc; +use std::sync::{Arc, Mutex, OnceLock}; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; use uuid::Uuid; @@ -83,6 +84,11 @@ impl MockWorker { .route("/v1/completions", post(completions_handler)) .route("/v1/rerank", post(rerank_handler)) .route("/v1/responses", post(responses_handler)) + .route("/v1/responses/{response_id}", get(responses_get_handler)) + .route( + "/v1/responses/{response_id}/cancel", + post(responses_cancel_handler), + ) .route("/flush_cache", post(flush_cache_handler)) .route("/v1/models", get(v1_models_handler)) .with_state(config); @@ -584,6 +590,21 @@ async fn responses_handler( .unwrap() .as_secs() as i64; + // Background storage simulation + let is_background = payload + .get("background") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + let req_id = payload + .get("request_id") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + if is_background { + if let Some(id) = &req_id { + store_response_for_port(config.port, id); + } + } + if is_stream { let request_id = format!("resp-{}", Uuid::new_v4()); @@ -610,6 +631,18 @@ async fn responses_handler( Sse::new(stream) .keep_alive(KeepAlive::default()) .into_response() + } else if is_background { + let rid = req_id.unwrap_or_else(|| format!("resp-{}", Uuid::new_v4())); + Json(json!({ + "id": rid, + "object": "response", + "created_at": timestamp, + "model": "mock-model", + "output": [], + "status": "queued", + "usage": null + })) + .into_response() } else { Json(json!({ "id": format!("resp-{}", Uuid::new_v4()), @@ -688,6 +721,95 @@ async fn v1_models_handler(State(config): State>>) .into_response() } +async fn responses_get_handler( + State(config): State>>, + Path(response_id): Path, +) -> Response { + let config = config.read().await; + if should_fail(&config).await { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": "Random failure for testing" })), + ) + .into_response(); + } + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + // Only return 200 if this worker "stores" the response id + if response_exists_for_port(config.port, &response_id) { + Json(json!({ + "id": response_id, + "object": "response", + "created_at": timestamp, + "model": "mock-model", + "output": [], + "status": "completed", + "usage": { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0 + } + })) + .into_response() + } else { + StatusCode::NOT_FOUND.into_response() + } +} + +async fn responses_cancel_handler( + State(config): State>>, + Path(response_id): Path, +) -> Response { + let config = config.read().await; + if should_fail(&config).await { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": "Random failure for testing" })), + ) + .into_response(); + } + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + if response_exists_for_port(config.port, &response_id) { + Json(json!({ + "id": response_id, + "object": "response", + "created_at": timestamp, + "model": "mock-model", + "output": [], + "status": "cancelled", + "usage": null + })) + .into_response() + } else { + StatusCode::NOT_FOUND.into_response() + } +} + +// --- Simple in-memory response store per worker port (for tests) --- +static RESP_STORE: OnceLock>>> = OnceLock::new(); + +fn get_store() -> &'static Mutex>> { + RESP_STORE.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn store_response_for_port(port: u16, response_id: &str) { + let mut map = get_store().lock().unwrap(); + map.entry(port).or_default().insert(response_id.to_string()); +} + +fn response_exists_for_port(port: u16, response_id: &str) -> bool { + let map = get_store().lock().unwrap(); + map.get(&port) + .map(|set| set.contains(response_id)) + .unwrap_or(false) +} + +// Minimal rerank handler returning mock results; router shapes final response async fn rerank_handler( State(config): State>>, Json(payload): Json,