diff --git a/sgl-router/src/routers/grpc/pd_router.rs b/sgl-router/src/routers/grpc/pd_router.rs index d227b460d..8c9645eca 100644 --- a/sgl-router/src/routers/grpc/pd_router.rs +++ b/sgl-router/src/routers/grpc/pd_router.rs @@ -289,6 +289,14 @@ impl RouterTrait for GrpcPDRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } + async fn route_responses( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::ResponsesRequest, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/grpc/router.rs b/sgl-router/src/routers/grpc/router.rs index 5c499125f..d42753fc1 100644 --- a/sgl-router/src/routers/grpc/router.rs +++ b/sgl-router/src/routers/grpc/router.rs @@ -222,6 +222,14 @@ impl RouterTrait for GrpcRouter { (StatusCode::NOT_IMPLEMENTED).into_response() } + async fn route_responses( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::ResponsesRequest, + ) -> Response { + (StatusCode::NOT_IMPLEMENTED).into_response() + } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { (StatusCode::NOT_IMPLEMENTED).into_response() } diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index 551dd1aa3..b06f20810 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -333,6 +333,18 @@ impl super::super::RouterTrait for OpenAIRouter { .into_response() } + async fn route_responses( + &self, + _headers: Option<&HeaderMap>, + _body: &crate::protocols::spec::ResponsesRequest, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Responses endpoint not implemented for OpenAI router", + ) + .into_response() + } + async fn flush_cache(&self) -> Response { ( StatusCode::NOT_IMPLEMENTED, diff --git a/sgl-router/src/routers/http/pd_router.rs b/sgl-router/src/routers/http/pd_router.rs index 528ead5f5..d66eb8077 100644 --- a/sgl-router/src/routers/http/pd_router.rs +++ b/sgl-router/src/routers/http/pd_router.rs @@ -9,8 +9,8 @@ use crate::core::{ use crate::metrics::RouterMetrics; use crate::policies::LoadBalancingPolicy; use crate::protocols::spec::{ - ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray, - UserMessageContent, + ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponsesRequest, + StringOrArray, UserMessageContent, }; use crate::routers::header_utils; use crate::routers::{RouterTrait, WorkerManagement}; @@ -1930,6 +1930,18 @@ impl RouterTrait for PDRouter { self.execute_dual_dispatch(headers, body, context).await } + async fn route_responses( + &self, + _headers: Option<&HeaderMap>, + _body: &ResponsesRequest, + ) -> Response { + ( + StatusCode::NOT_IMPLEMENTED, + "Responses endpoint not implemented for PD router", + ) + .into_response() + } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { todo!() } diff --git a/sgl-router/src/routers/http/router.rs b/sgl-router/src/routers/http/router.rs index 176c38602..f0dc4f3b5 100644 --- a/sgl-router/src/routers/http/router.rs +++ b/sgl-router/src/routers/http/router.rs @@ -6,7 +6,7 @@ use crate::core::{ use crate::metrics::RouterMetrics; use crate::policies::LoadBalancingPolicy; use crate::protocols::spec::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, + ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, ResponsesRequest, }; use crate::routers::header_utils; use crate::routers::{RouterTrait, WorkerManagement}; @@ -1210,6 +1210,15 @@ impl RouterTrait for Router { .await } + async fn route_responses( + &self, + headers: Option<&HeaderMap>, + body: &ResponsesRequest, + ) -> Response { + self.route_typed_request(headers, body, "/v1/responses") + .await + } + async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response { todo!() } diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index e610fedb3..6c12edbc8 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -9,7 +9,9 @@ use axum::{ }; use std::fmt::Debug; -use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; +use crate::protocols::spec::{ + ChatCompletionRequest, CompletionRequest, GenerateRequest, ResponsesRequest, +}; pub mod factory; pub mod grpc; @@ -78,6 +80,13 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement { body: &CompletionRequest, ) -> Response; + /// Route a responses request + async fn route_responses( + &self, + headers: Option<&HeaderMap>, + body: &ResponsesRequest, + ) -> Response; + async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response; async fn route_rerank(&self, headers: Option<&HeaderMap>, body: Body) -> Response; diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 2762f9765..9aca370ae 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -2,7 +2,9 @@ use crate::config::RouterConfig; use crate::logging::{self, LoggingConfig}; use crate::metrics::{self, PrometheusConfig}; use crate::middleware::TokenBucket; -use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; +use crate::protocols::spec::{ + ChatCompletionRequest, CompletionRequest, GenerateRequest, ResponsesRequest, +}; use crate::reasoning_parser::ParserFactory; use crate::routers::{RouterFactory, RouterTrait}; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; @@ -150,6 +152,14 @@ async fn v1_completions( state.router.route_completion(Some(&headers), &body).await } +async fn v1_responses( + State(state): State>, + headers: http::HeaderMap, + Json(body): Json, +) -> Response { + state.router.route_responses(Some(&headers), &body).await +} + // Worker management endpoints async fn add_worker( State(state): State>, @@ -227,6 +237,7 @@ pub fn build_app( .route("/generate", post(generate)) .route("/v1/chat/completions", post(v1_chat_completions)) .route("/v1/completions", post(v1_completions)) + .route("/v1/responses", post(v1_responses)) .route_layer(axum::middleware::from_fn_with_state( app_state.clone(), crate::middleware::concurrency_limit_middleware, diff --git a/sgl-router/tests/api_endpoints_test.rs b/sgl-router/tests/api_endpoints_test.rs index 8b2e29714..09099a7b8 100644 --- a/sgl-router/tests/api_endpoints_test.rs +++ b/sgl-router/tests/api_endpoints_test.rs @@ -991,6 +991,91 @@ mod router_policy_tests { } } +#[cfg(test)] +mod responses_endpoint_tests { + use super::*; + + #[tokio::test] + async fn test_v1_responses_non_streaming() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18950, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "input": "Hello Responses API", + "model": "mock-model", + "stream": false + }); + + let req = Request::builder() + .method("POST") + .uri("/v1/responses") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(body_json["object"], "response"); + assert_eq!(body_json["status"], "completed"); + + ctx.shutdown().await; + } + + #[tokio::test] + async fn test_v1_responses_streaming() { + let ctx = TestContext::new(vec![MockWorkerConfig { + port: 18951, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let app = ctx.create_app().await; + + let payload = json!({ + "input": "Hello Responses API", + "model": "mock-model", + "stream": true + }); + + let req = Request::builder() + .method("POST") + .uri("/v1/responses") + .header(CONTENT_TYPE, "application/json") + .body(Body::from(serde_json::to_string(&payload).unwrap())) + .unwrap(); + + let resp = app.clone().oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + // Check that content-type indicates SSE + let headers = resp.headers().clone(); + let ct = headers + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + assert!(ct.contains("text/event-stream")); + + // We don't fully consume the stream in this test harness. + ctx.shutdown().await; + } +} + #[cfg(test)] mod error_tests { use super::*; diff --git a/sgl-router/tests/common/mock_worker.rs b/sgl-router/tests/common/mock_worker.rs index 16d721607..a04a76008 100755 --- a/sgl-router/tests/common/mock_worker.rs +++ b/sgl-router/tests/common/mock_worker.rs @@ -81,6 +81,7 @@ impl MockWorker { .route("/generate", post(generate_handler)) .route("/v1/chat/completions", post(chat_completions_handler)) .route("/v1/completions", post(completions_handler)) + .route("/v1/responses", post(responses_handler)) .route("/flush_cache", post(flush_cache_handler)) .route("/v1/models", get(v1_models_handler)) .with_state(config); @@ -548,6 +549,91 @@ async fn completions_handler( } } +async fn responses_handler( + State(config): State>>, + Json(payload): Json, +) -> Response { + let config = config.read().await; + + if should_fail(&config).await { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": "Random failure for testing", + "type": "internal_error", + "code": "internal_error" + } + })), + ) + .into_response(); + } + + if config.response_delay_ms > 0 { + tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await; + } + + let is_stream = payload + .get("stream") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + if is_stream { + let request_id = format!("resp-{}", Uuid::new_v4()); + + let stream = stream::once(async move { + let chunk = json!({ + "id": request_id, + "object": "response", + "created_at": timestamp, + "model": "mock-model", + "status": "in_progress", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{ + "type": "output_text", + "text": "This is a mock responses streamed output." + }] + }] + }); + Ok::<_, Infallible>(Event::default().data(chunk.to_string())) + }) + .chain(stream::once(async { Ok(Event::default().data("[DONE]")) })); + + Sse::new(stream) + .keep_alive(KeepAlive::default()) + .into_response() + } else { + Json(json!({ + "id": format!("resp-{}", Uuid::new_v4()), + "object": "response", + "created_at": timestamp, + "model": "mock-model", + "output": [{ + "type": "message", + "role": "assistant", + "content": [{ + "type": "output_text", + "text": "This is a mock responses output." + }] + }], + "status": "completed", + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15 + } + })) + .into_response() + } +} + async fn flush_cache_handler(State(config): State>>) -> Response { let config = config.read().await;