diff --git a/sgl-router/src/routers/http/openai_router.rs b/sgl-router/src/routers/http/openai_router.rs index 2187ce898..c5ba55af9 100644 --- a/sgl-router/src/routers/http/openai_router.rs +++ b/sgl-router/src/routers/http/openai_router.rs @@ -17,11 +17,14 @@ use axum::{ http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, response::{IntoResponse, Response}, }; +use bytes::Bytes; use futures_util::StreamExt; use serde_json::{json, to_value, Value}; use std::{ any::Any, + borrow::Cow, collections::HashMap, + io, sync::atomic::{AtomicBool, Ordering}, }; use tokio::sync::mpsc; @@ -51,6 +54,142 @@ impl std::fmt::Debug for OpenAIRouter { } } +/// Helper that parses SSE frames from the OpenAI responses stream and +/// accumulates enough information to persist the final response locally. +struct StreamingResponseAccumulator { + /// The initial `response.created` payload (if emitted). + initial_response: Option, + /// The final `response.completed` payload (if emitted). + completed_response: Option, + /// Collected output items keyed by the upstream output index, used when + /// a final response payload is absent and we need to synthesize one. + output_items: Vec<(usize, Value)>, + /// Captured error payload (if the upstream stream fails midway). + encountered_error: Option, +} + +impl StreamingResponseAccumulator { + fn new() -> Self { + Self { + initial_response: None, + completed_response: None, + output_items: Vec::new(), + encountered_error: None, + } + } + + /// Feed the accumulator with the next SSE chunk. + fn ingest_block(&mut self, block: &str) { + if block.trim().is_empty() { + return; + } + self.process_block(block); + } + + /// Consume the accumulator and produce the best-effort final response value. + fn into_final_response(mut self) -> Option { + if self.completed_response.is_some() { + return self.completed_response; + } + + self.build_fallback_response() + } + + fn encountered_error(&self) -> Option<&Value> { + self.encountered_error.as_ref() + } + fn process_block(&mut self, block: &str) { + let trimmed = block.trim(); + if trimmed.is_empty() { + return; + } + + let mut event_name: Option = None; + let mut data_lines: Vec = Vec::new(); + + for line in trimmed.lines() { + if let Some(rest) = line.strip_prefix("event:") { + event_name = Some(rest.trim().to_string()); + } else if let Some(rest) = line.strip_prefix("data:") { + data_lines.push(rest.trim_start().to_string()); + } + } + + let data_payload = data_lines.join("\n"); + if data_payload.is_empty() { + return; + } + + self.handle_event(event_name.as_deref(), &data_payload); + } + + fn handle_event(&mut self, event_name: Option<&str>, data_payload: &str) { + let parsed: Value = match serde_json::from_str(data_payload) { + Ok(value) => value, + Err(err) => { + warn!("Failed to parse streaming event JSON: {}", err); + return; + } + }; + + let event_type = event_name + .map(|s| s.to_string()) + .or_else(|| { + parsed + .get("type") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_default(); + + match event_type.as_str() { + "response.created" => { + if let Some(response) = parsed.get("response") { + self.initial_response = Some(response.clone()); + } + } + "response.completed" => { + if let Some(response) = parsed.get("response") { + self.completed_response = Some(response.clone()); + } + } + "response.output_item.done" => { + if let (Some(index), Some(item)) = ( + parsed + .get("output_index") + .and_then(|v| v.as_u64()) + .map(|v| v as usize), + parsed.get("item"), + ) { + self.output_items.push((index, item.clone())); + } + } + "response.error" => { + self.encountered_error = Some(parsed); + } + _ => {} + } + } + + fn build_fallback_response(&mut self) -> Option { + let mut response = self.initial_response.clone()?; + + if let Some(obj) = response.as_object_mut() { + obj.insert("status".to_string(), Value::String("completed".to_string())); + + self.output_items.sort_by_key(|(index, _)| *index); + let outputs: Vec = self + .output_items + .iter() + .map(|(_, item)| item.clone()) + .collect(); + obj.insert("output".to_string(), Value::Array(outputs)); + } + + Some(response) + } +} + impl OpenAIRouter { /// Create a new OpenAI router pub async fn new( @@ -203,17 +342,159 @@ impl OpenAIRouter { async fn handle_streaming_response( &self, - _url: String, - _headers: Option<&HeaderMap>, - _payload: Value, - _original_body: &ResponsesRequest, - _original_previous_response_id: Option, + url: String, + headers: Option<&HeaderMap>, + payload: Value, + original_body: &ResponsesRequest, + original_previous_response_id: Option, ) -> Response { - ( - StatusCode::NOT_IMPLEMENTED, - "Streaming responses not yet implemented", - ) - .into_response() + let mut request_builder = self.client.post(&url).json(&payload); + + if let Some(headers) = headers { + request_builder = apply_request_headers(headers, request_builder, true); + } + + request_builder = request_builder.header("Accept", "text/event-stream"); + + let response = match request_builder.send().await { + Ok(resp) => resp, + Err(err) => { + self.circuit_breaker.record_failure(); + return ( + StatusCode::BAD_GATEWAY, + format!("Failed to forward request to OpenAI: {}", err), + ) + .into_response(); + } + }; + + let status = response.status(); + let status_code = + StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + if !status.is_success() { + self.circuit_breaker.record_failure(); + let error_body = match response.text().await { + Ok(body) => body, + Err(err) => format!("Failed to read upstream error body: {}", err), + }; + return (status_code, error_body).into_response(); + } + + self.circuit_breaker.record_success(); + + let preserved_headers = preserve_response_headers(response.headers()); + let mut upstream_stream = response.bytes_stream(); + + let (tx, rx) = mpsc::unbounded_channel::>(); + + let should_store = original_body.store; + let storage = self.response_storage.clone(); + let original_request = original_body.clone(); + let previous_response_id = original_previous_response_id.clone(); + + tokio::spawn(async move { + let mut accumulator = StreamingResponseAccumulator::new(); + let mut upstream_failed = false; + let mut receiver_connected = true; + let mut pending = String::new(); + + while let Some(chunk_result) = upstream_stream.next().await { + match chunk_result { + Ok(chunk) => { + let chunk_text = match std::str::from_utf8(&chunk) { + Ok(text) => Cow::Borrowed(text), + Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()), + }; + + pending.push_str(&chunk_text.replace("\r\n", "\n")); + + while let Some(pos) = pending.find("\n\n") { + let raw_block = pending[..pos].to_string(); + pending.drain(..pos + 2); + + if raw_block.trim().is_empty() { + continue; + } + + let block_cow = if let Some(modified) = Self::rewrite_streaming_block( + raw_block.as_str(), + &original_request, + previous_response_id.as_deref(), + ) { + Cow::Owned(modified) + } else { + Cow::Borrowed(raw_block.as_str()) + }; + + if should_store { + accumulator.ingest_block(block_cow.as_ref()); + } + + if receiver_connected { + let chunk_to_send = format!("{}\n\n", block_cow); + if tx.send(Ok(Bytes::from(chunk_to_send))).is_err() { + receiver_connected = false; + } + } + + if !receiver_connected && !should_store { + break; + } + } + + if !receiver_connected && !should_store { + break; + } + } + Err(err) => { + upstream_failed = true; + let io_err = io::Error::other(err); + let _ = tx.send(Err(io_err)); + break; + } + } + } + + if should_store && !upstream_failed { + if !pending.trim().is_empty() { + accumulator.ingest_block(&pending); + } + let encountered_error = accumulator.encountered_error().cloned(); + if let Some(mut response_json) = accumulator.into_final_response() { + Self::patch_streaming_response_json( + &mut response_json, + &original_request, + previous_response_id.as_deref(), + ); + + if let Err(err) = + Self::store_response_impl(&storage, &response_json, &original_request).await + { + warn!("Failed to store streaming response: {}", err); + } + } else if let Some(error_payload) = encountered_error { + warn!("Upstream streaming error payload: {}", error_payload); + } else { + warn!("Streaming completed without a final response payload"); + } + } + }); + + let body_stream = UnboundedReceiverStream::new(rx); + let mut response = Response::new(Body::from_stream(body_stream)); + *response.status_mut() = status_code; + + let headers_mut = response.headers_mut(); + for (name, value) in preserved_headers.iter() { + headers_mut.insert(name, value.clone()); + } + + if !headers_mut.contains_key(CONTENT_TYPE) { + headers_mut.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream")); + } + + response } async fn store_response_internal( @@ -300,6 +581,172 @@ impl OpenAIRouter { .map_err(|e| format!("Failed to store response: {}", e)) } + fn patch_streaming_response_json( + response_json: &mut Value, + original_body: &ResponsesRequest, + original_previous_response_id: Option<&str>, + ) { + if let Some(obj) = response_json.as_object_mut() { + if let Some(prev_id) = original_previous_response_id { + let should_insert = obj + .get("previous_response_id") + .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false)) + .unwrap_or(true); + if should_insert { + obj.insert( + "previous_response_id".to_string(), + Value::String(prev_id.to_string()), + ); + } + } + + if !obj.contains_key("instructions") + || obj + .get("instructions") + .map(|v| v.is_null()) + .unwrap_or(false) + { + if let Some(instructions) = &original_body.instructions { + obj.insert( + "instructions".to_string(), + Value::String(instructions.clone()), + ); + } + } + + if !obj.contains_key("metadata") + || obj.get("metadata").map(|v| v.is_null()).unwrap_or(false) + { + if let Some(metadata) = &original_body.metadata { + let metadata_map: serde_json::Map = metadata + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + obj.insert("metadata".to_string(), Value::Object(metadata_map)); + } + } + + obj.insert("store".to_string(), Value::Bool(original_body.store)); + + if obj + .get("model") + .and_then(|v| v.as_str()) + .map(|s| s.is_empty()) + .unwrap_or(true) + { + if let Some(model) = &original_body.model { + obj.insert("model".to_string(), Value::String(model.clone())); + } + } + + if obj.get("user").map(|v| v.is_null()).unwrap_or(false) { + if let Some(user) = &original_body.user { + obj.insert("user".to_string(), Value::String(user.clone())); + } + } + } + } + + fn rewrite_streaming_block( + block: &str, + original_body: &ResponsesRequest, + original_previous_response_id: Option<&str>, + ) -> Option { + let trimmed = block.trim(); + if trimmed.is_empty() { + return None; + } + + let mut data_lines: Vec = Vec::new(); + + for line in trimmed.lines() { + if line.starts_with("data:") { + data_lines.push(line.trim_start_matches("data:").trim_start().to_string()); + } + } + + if data_lines.is_empty() { + return None; + } + + let payload = data_lines.join("\n"); + let mut parsed: Value = match serde_json::from_str(&payload) { + Ok(value) => value, + Err(err) => { + warn!("Failed to parse streaming JSON payload: {}", err); + return None; + } + }; + + let event_type = parsed + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or_default(); + + let should_patch = matches!( + event_type, + "response.created" | "response.in_progress" | "response.completed" + ); + + if !should_patch { + return None; + } + + let mut changed = false; + if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) { + let desired_store = Value::Bool(original_body.store); + if response_obj.get("store") != Some(&desired_store) { + response_obj.insert("store".to_string(), desired_store); + changed = true; + } + + if let Some(prev_id) = original_previous_response_id { + let needs_previous = response_obj + .get("previous_response_id") + .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false)) + .unwrap_or(true); + + if needs_previous { + response_obj.insert( + "previous_response_id".to_string(), + Value::String(prev_id.to_string()), + ); + changed = true; + } + } + } + + if !changed { + return None; + } + + let new_payload = match serde_json::to_string(&parsed) { + Ok(json) => json, + Err(err) => { + warn!("Failed to serialize modified streaming payload: {}", err); + return None; + } + }; + + let mut rebuilt_lines = Vec::new(); + let mut data_written = false; + for line in trimmed.lines() { + if line.starts_with("data:") { + if !data_written { + rebuilt_lines.push(format!("data: {}", new_payload)); + data_written = true; + } + } else { + rebuilt_lines.push(line.to_string()); + } + } + + if !data_written { + rebuilt_lines.push(format!("data: {}", new_payload)); + } + + Some(rebuilt_lines.join("\n")) + } fn extract_primary_output_text(response_json: &Value) -> Option { if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) { for item in items { @@ -595,14 +1042,6 @@ impl super::super::RouterTrait for OpenAIRouter { "openai_responses_request" ); - if body.stream { - return ( - StatusCode::NOT_IMPLEMENTED, - "Streaming responses not yet implemented", - ) - .into_response(); - } - // Clone the body and override model if needed let mut request_body = body.clone(); if let Some(model) = model_id { diff --git a/sgl-router/tests/test_openai_routing.rs b/sgl-router/tests/test_openai_routing.rs index 624ab9080..55ebce64f 100644 --- a/sgl-router/tests/test_openai_routing.rs +++ b/sgl-router/tests/test_openai_routing.rs @@ -4,24 +4,27 @@ use axum::{ body::Body, extract::Request, http::{Method, StatusCode}, + response::Response, routing::post, Json, Router, }; use serde_json::json; use sglang_router_rs::{ config::{RouterConfig, RoutingMode}, - data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage}, + data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage, StoredResponse}, protocols::spec::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput, ResponsesGetParams, ResponsesRequest, UserMessageContent, }, routers::{openai_router::OpenAIRouter, RouterTrait}, }; +use std::collections::HashMap; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, }; use tokio::net::TcpListener; +use tokio::time::{sleep, Duration}; use tower::ServiceExt; mod common; @@ -309,6 +312,255 @@ async fn test_openai_router_responses_with_mock() { server.abort(); } +#[tokio::test] +async fn test_openai_router_responses_streaming_with_mock() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let sse_handler = post(|Json(_request): Json| async move { + let response_id = "resp_stream_123"; + let message_id = "msg_stream_123"; + let final_text = "Once upon a streamed unicorn adventure."; + + let events = vec![ + ( + "response.created", + json!({ + "type": "response.created", + "sequence_number": 0, + "response": { + "id": response_id, + "object": "response", + "created_at": 1_700_000_500, + "status": "in_progress", + "model": "", + "output": [], + "parallel_tool_calls": true, + "previous_response_id": null, + "reasoning": null, + "store": false, + "temperature": 1.0, + "text": {"format": {"type": "text"}}, + "tool_choice": "auto", + "tools": [], + "top_p": 1.0, + "truncation": "disabled", + "usage": null, + "metadata": null + } + }), + ), + ( + "response.output_item.added", + json!({ + "type": "response.output_item.added", + "sequence_number": 1, + "output_index": 0, + "item": { + "id": message_id, + "type": "message", + "role": "assistant", + "status": "in_progress", + "content": [] + } + }), + ), + ( + "response.output_text.delta", + json!({ + "type": "response.output_text.delta", + "sequence_number": 2, + "item_id": message_id, + "output_index": 0, + "content_index": 0, + "delta": "Once upon a streamed unicorn adventure.", + "logprobs": [] + }), + ), + ( + "response.output_text.done", + json!({ + "type": "response.output_text.done", + "sequence_number": 3, + "item_id": message_id, + "output_index": 0, + "content_index": 0, + "text": final_text, + "logprobs": [] + }), + ), + ( + "response.output_item.done", + json!({ + "type": "response.output_item.done", + "sequence_number": 4, + "output_index": 0, + "item": { + "id": message_id, + "type": "message", + "role": "assistant", + "status": "completed", + "content": [{ + "type": "output_text", + "text": final_text, + "annotations": [], + "logprobs": [] + }] + } + }), + ), + ( + "response.completed", + json!({ + "type": "response.completed", + "sequence_number": 5, + "response": { + "id": response_id, + "object": "response", + "created_at": 1_700_000_500, + "status": "completed", + "model": "", + "output": [{ + "id": message_id, + "type": "message", + "role": "assistant", + "status": "completed", + "content": [{ + "type": "output_text", + "text": final_text, + "annotations": [], + "logprobs": [] + }] + }], + "parallel_tool_calls": true, + "previous_response_id": null, + "reasoning": null, + "store": false, + "temperature": 1.0, + "text": {"format": {"type": "text"}}, + "tool_choice": "auto", + "tools": [], + "top_p": 1.0, + "truncation": "disabled", + "usage": { + "input_tokens": 10, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens": 20, + "output_tokens_details": {"reasoning_tokens": 5}, + "total_tokens": 30 + }, + "metadata": null, + "instructions": null, + "user": null + } + }), + ), + ]; + + let sse_payload = events + .into_iter() + .map(|(event, data)| format!("event: {}\ndata: {}\n\n", event, data)) + .collect::(); + + Response::builder() + .status(StatusCode::OK) + .header("content-type", "text/event-stream") + .body(Body::from(sse_payload)) + .unwrap() + }); + + let app = Router::new().route("/v1/responses", sse_handler); + + let server = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let base_url = format!("http://{}", addr); + let storage = Arc::new(MemoryResponseStorage::new()); + + // Seed a previous response so previous_response_id logic has data to pull from. + let mut previous = StoredResponse::new( + "Earlier bedtime question".to_string(), + "Earlier answer".to_string(), + None, + ); + previous.id = ResponseId::from_string("resp_prev_chain".to_string()); + storage.store_response(previous).await.unwrap(); + + let router = OpenAIRouter::new(base_url, None, storage.clone()) + .await + .unwrap(); + + let mut metadata = HashMap::new(); + metadata.insert("topic".to_string(), json!("unicorns")); + + let request = ResponsesRequest { + model: Some("gpt-5-nano".to_string()), + input: ResponseInput::Text("Tell me a bedtime story.".to_string()), + instructions: Some("Be kind".to_string()), + metadata: Some(metadata), + previous_response_id: Some("resp_prev_chain".to_string()), + store: true, + stream: true, + ..Default::default() + }; + + let response = router.route_responses(None, &request, None).await; + assert_eq!(response.status(), StatusCode::OK); + + let headers = response.headers(); + let ct = headers + .get("content-type") + .unwrap() + .to_str() + .unwrap() + .to_ascii_lowercase(); + assert!(ct.contains("text/event-stream")); + + let response_body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let body_text = String::from_utf8(response_body.to_vec()).unwrap(); + assert!(body_text.contains("response.completed")); + assert!(body_text.contains("Once upon a streamed unicorn adventure.")); + + // Wait for the storage task to persist the streaming response. + let target_id = ResponseId::from_string("resp_stream_123".to_string()); + let stored = loop { + if let Some(resp) = storage.get_response(&target_id).await.unwrap() { + break resp; + } + sleep(Duration::from_millis(10)).await; + }; + + assert_eq!(stored.input, "Tell me a bedtime story."); + assert_eq!(stored.output, "Once upon a streamed unicorn adventure."); + assert_eq!( + stored + .previous_response_id + .as_ref() + .expect("previous_response_id missing") + .0, + "resp_prev_chain" + ); + assert_eq!(stored.metadata.get("topic"), Some(&json!("unicorns"))); + assert_eq!(stored.instructions.as_deref(), Some("Be kind")); + assert_eq!(stored.model.as_deref(), Some("gpt-5-nano")); + assert_eq!(stored.user, None); + assert_eq!(stored.raw_response["store"], json!(true)); + assert_eq!( + stored.raw_response["previous_response_id"].as_str(), + Some("resp_prev_chain") + ); + assert_eq!(stored.raw_response["metadata"]["topic"], json!("unicorns")); + assert_eq!( + stored.raw_response["instructions"].as_str(), + Some("Be kind") + ); + + server.abort(); +} + /// Test router factory with OpenAI routing mode #[tokio::test] async fn test_router_factory_openai_mode() {