[router] Support streaming for Openai Router Response api (#10822)

This commit is contained in:
Keyang Ru
2025-09-23 14:56:28 -07:00
committed by GitHub
parent 312bfc4c95
commit f4e3ebeb05
2 changed files with 710 additions and 19 deletions

View File

@@ -17,11 +17,14 @@ use axum::{
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use bytes::Bytes;
use futures_util::StreamExt;
use serde_json::{json, to_value, Value};
use std::{
any::Any,
borrow::Cow,
collections::HashMap,
io,
sync::atomic::{AtomicBool, Ordering},
};
use tokio::sync::mpsc;
@@ -51,6 +54,142 @@ impl std::fmt::Debug for OpenAIRouter {
}
}
/// Helper that parses SSE frames from the OpenAI responses stream and
/// accumulates enough information to persist the final response locally.
///
/// Blocks are fed in one at a time; the accumulated state is later consumed
/// to obtain either the upstream-provided final response or, failing that, a
/// response synthesized from the pieces collected along the way.
struct StreamingResponseAccumulator {
/// The `response` object carried by the `response.created` event, if one
/// was seen. It is the template for a synthesized fallback response.
initial_response: Option<Value>,
/// The `response` object carried by the `response.completed` event, if one
/// was seen. When present it is used as the final response as-is.
completed_response: Option<Value>,
/// Finished output items paired with their upstream `output_index`. Only
/// needed when no `response.completed` payload arrives and the `output`
/// array has to be rebuilt in index order.
output_items: Vec<(usize, Value)>,
/// The full payload of an error event reported by upstream mid-stream
/// (if any).
encountered_error: Option<Value>,
}
impl StreamingResponseAccumulator {
    /// Creates an accumulator that has not observed any event yet.
    fn new() -> Self {
        Self {
            initial_response: None,
            completed_response: None,
            output_items: Vec::new(),
            encountered_error: None,
        }
    }

    /// Feeds the accumulator with the next SSE block (the text between two
    /// blank lines). Blank or whitespace-only blocks are ignored.
    fn ingest_block(&mut self, block: &str) {
        // `process_block` already discards empty input, so no separate check
        // is needed here.
        self.process_block(block);
    }

    /// Consumes the accumulator and produces the best-effort final response.
    ///
    /// Returns the `response.completed` payload when one was received;
    /// otherwise falls back to a response synthesized from the
    /// `response.created` payload and the collected output items. Returns
    /// `None` when neither is available.
    fn into_final_response(mut self) -> Option<Value> {
        self.completed_response
            .take()
            .or_else(|| self.build_fallback_response())
    }

    /// Returns the error payload captured from the stream, if any.
    fn encountered_error(&self) -> Option<&Value> {
        self.encountered_error.as_ref()
    }

    /// Splits one SSE block into its `event:` name and `data:` payload and
    /// dispatches it to [`Self::handle_event`].
    ///
    /// Multiple `data:` lines are joined with `\n`. Blocks without any data
    /// are dropped silently.
    fn process_block(&mut self, block: &str) {
        let mut event_name: Option<&str> = None;
        let mut data_lines: Vec<&str> = Vec::new();

        for line in block.trim().lines() {
            if let Some(rest) = line.strip_prefix("event:") {
                event_name = Some(rest.trim());
            } else if let Some(rest) = line.strip_prefix("data:") {
                data_lines.push(rest.trim_start());
            }
        }

        let data_payload = data_lines.join("\n");
        if data_payload.is_empty() {
            return;
        }

        self.handle_event(event_name, &data_payload);
    }

    /// Parses the JSON payload of a single event and records the parts that
    /// matter for reconstructing the final response.
    ///
    /// The event kind is taken from the SSE `event:` field when it is
    /// non-empty, otherwise from the payload's `type` field. Payloads that are
    /// not valid JSON are logged and skipped.
    fn handle_event(&mut self, event_name: Option<&str>, data_payload: &str) {
        let parsed: Value = match serde_json::from_str(data_payload) {
            Ok(value) => value,
            Err(err) => {
                warn!("Failed to parse streaming event JSON: {}", err);
                return;
            }
        };

        // An `event:` line that is present but empty must not shadow the
        // payload's own `type`; otherwise the event would be ignored entirely.
        // The name is copied out so `parsed` can be moved in the error arm.
        let event_type = event_name
            .filter(|name| !name.is_empty())
            .or_else(|| parsed.get("type").and_then(Value::as_str))
            .unwrap_or_default()
            .to_owned();

        match event_type.as_str() {
            "response.created" => {
                if let Some(response) = parsed.get("response") {
                    self.initial_response = Some(response.clone());
                }
            }
            "response.completed" => {
                if let Some(response) = parsed.get("response") {
                    self.completed_response = Some(response.clone());
                }
            }
            "response.output_item.done" => {
                let index = parsed
                    .get("output_index")
                    .and_then(Value::as_u64)
                    .map(|raw| raw as usize);
                if let (Some(index), Some(item)) = (index, parsed.get("item")) {
                    self.output_items.push((index, item.clone()));
                }
            }
            // `error` is the event name listed for stream-level errors in
            // OpenAI's streaming events reference; `response.error` is kept
            // for compatibility with what this code recognised before.
            "response.error" | "error" => {
                self.encountered_error = Some(parsed);
            }
            _ => {}
        }
    }

    /// Synthesizes a final response from the `response.created` payload when
    /// upstream never sent `response.completed`.
    ///
    /// The initial payload is cloned, its `status` is set to `"completed"`,
    /// and its `output` is replaced by the collected items ordered by their
    /// upstream index. Returns `None` if no initial payload was captured.
    ///
    /// NOTE(review): the status is forced to `"completed"` even when an error
    /// event was recorded — confirm that this is the intended behavior.
    fn build_fallback_response(&mut self) -> Option<Value> {
        let mut response = self.initial_response.clone()?;

        if let Some(obj) = response.as_object_mut() {
            obj.insert("status".to_string(), Value::String("completed".to_string()));

            // Stable sort: items sharing an index keep their arrival order.
            self.output_items.sort_by_key(|(index, _)| *index);
            let outputs: Vec<Value> = self
                .output_items
                .iter()
                .map(|(_, item)| item.clone())
                .collect();
            obj.insert("output".to_string(), Value::Array(outputs));
        }

        Some(response)
    }
}
impl OpenAIRouter {
/// Create a new OpenAI router
pub async fn new(
@@ -203,17 +342,159 @@ impl OpenAIRouter {
/// Forwards a streaming request to the upstream endpoint and relays the
/// resulting SSE stream back to the caller, optionally storing the final
/// response once the stream ends.
///
/// Returns `502 Bad Gateway` if the upstream request cannot be sent, the
/// upstream status and body if upstream answers with a non-success status,
/// and otherwise a streaming response fed by a background task.
// NOTE(review): this view appears to interleave the removed stub (the
// underscore-prefixed parameters and the NOT_IMPLEMENTED body) with its
// replacement — confirm against the actual file before relying on it.
async fn handle_streaming_response(
&self,
// Parameters of the previous stub (unused, hence the underscore prefix).
_url: String,
_headers: Option<&HeaderMap>,
_payload: Value,
_original_body: &ResponsesRequest,
_original_previous_response_id: Option<String>,
// Parameters of the streaming implementation.
url: String,
headers: Option<&HeaderMap>,
payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
) -> Response {
// Body of the previous stub.
(
StatusCode::NOT_IMPLEMENTED,
"Streaming responses not yet implemented",
)
.into_response()
// Build the upstream POST with the JSON payload, the caller's headers (via
// `apply_request_headers`), and an explicit SSE `Accept` header.
let mut request_builder = self.client.post(&url).json(&payload);
if let Some(headers) = headers {
request_builder = apply_request_headers(headers, request_builder, true);
}
request_builder = request_builder.header("Accept", "text/event-stream");
// A transport-level failure counts against the circuit breaker and is
// reported to the caller as 502.
let response = match request_builder.send().await {
Ok(resp) => resp,
Err(err) => {
self.circuit_breaker.record_failure();
return (
StatusCode::BAD_GATEWAY,
format!("Failed to forward request to OpenAI: {}", err),
)
.into_response();
}
};
// Convert the upstream status into this crate's `StatusCode`, falling back
// to 500 if the numeric code cannot be represented.
let status = response.status();
let status_code =
StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
// Non-success upstream statuses are passed through together with the
// upstream body (or a note that the body could not be read).
if !status.is_success() {
self.circuit_breaker.record_failure();
let error_body = match response.text().await {
Ok(body) => body,
Err(err) => format!("Failed to read upstream error body: {}", err),
};
return (status_code, error_body).into_response();
}
self.circuit_breaker.record_success();
let preserved_headers = preserve_response_headers(response.headers());
let mut upstream_stream = response.bytes_stream();
// The background task pushes rewritten SSE blocks into this unbounded
// channel; the receiving end becomes the response body below.
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
// Everything the task needs is cloned or moved so it does not borrow `self`
// or the request.
let should_store = original_body.store;
let storage = self.response_storage.clone();
let original_request = original_body.clone();
let previous_response_id = original_previous_response_id.clone();
tokio::spawn(async move {
let mut accumulator = StreamingResponseAccumulator::new();
// Set when the upstream byte stream yields an error; suppresses storage.
let mut upstream_failed = false;
// Cleared once sending to the client fails (receiver dropped).
let mut receiver_connected = true;
// Text received so far that does not yet form a complete SSE block.
let mut pending = String::new();
while let Some(chunk_result) = upstream_stream.next().await {
match chunk_result {
Ok(chunk) => {
// NOTE(review): decoding happens per chunk, so a multi-byte UTF-8
// sequence split across two chunks takes the lossy path here, and a
// `\r\n` pair split across two chunks is not normalized below.
let chunk_text = match std::str::from_utf8(&chunk) {
Ok(text) => Cow::Borrowed(text),
Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()),
};
// Normalize CRLF so that block boundaries are always "\n\n".
pending.push_str(&chunk_text.replace("\r\n", "\n"));
// Peel complete blocks off the front of the buffer.
while let Some(pos) = pending.find("\n\n") {
let raw_block = pending[..pos].to_string();
pending.drain(..pos + 2);
if raw_block.trim().is_empty() {
continue;
}
// Use the rewritten block when the rewriter changed something,
// otherwise forward the upstream block untouched.
let block_cow = if let Some(modified) = Self::rewrite_streaming_block(
raw_block.as_str(),
&original_request,
previous_response_id.as_deref(),
) {
Cow::Owned(modified)
} else {
Cow::Borrowed(raw_block.as_str())
};
if should_store {
accumulator.ingest_block(block_cow.as_ref());
}
if receiver_connected {
// Re-append the blank-line terminator removed by the split above.
let chunk_to_send = format!("{}\n\n", block_cow);
if tx.send(Ok(Bytes::from(chunk_to_send))).is_err() {
receiver_connected = false;
}
}
// With the client gone and nothing to store there is no reason to
// keep reading; otherwise keep draining so the response can be stored.
if !receiver_connected && !should_store {
break;
}
}
// Same condition again to leave the outer read loop as well.
if !receiver_connected && !should_store {
break;
}
}
Err(err) => {
// Surface the upstream failure to the client and stop reading.
upstream_failed = true;
let io_err = io::Error::other(err);
let _ = tx.send(Err(io_err));
break;
}
}
}
// Persist only when requested and when the upstream stream ended cleanly.
if should_store && !upstream_failed {
// A trailing block without a terminating blank line is still ingested
// for storage; it is not forwarded to the client here.
if !pending.trim().is_empty() {
accumulator.ingest_block(&pending);
}
// Clone the error first: `into_final_response` consumes the accumulator.
let encountered_error = accumulator.encountered_error().cloned();
if let Some(mut response_json) = accumulator.into_final_response() {
Self::patch_streaming_response_json(
&mut response_json,
&original_request,
previous_response_id.as_deref(),
);
if let Err(err) =
Self::store_response_impl(&storage, &response_json, &original_request).await
{
warn!("Failed to store streaming response: {}", err);
}
} else if let Some(error_payload) = encountered_error {
warn!("Upstream streaming error payload: {}", error_payload);
} else {
warn!("Streaming completed without a final response payload");
}
}
});
// Build the client-facing response around the channel receiver.
let body_stream = UnboundedReceiverStream::new(rx);
let mut response = Response::new(Body::from_stream(body_stream));
*response.status_mut() = status_code;
let headers_mut = response.headers_mut();
for (name, value) in preserved_headers.iter() {
headers_mut.insert(name, value.clone());
}
// Default to SSE when no content type was carried over from upstream.
if !headers_mut.contains_key(CONTENT_TYPE) {
headers_mut.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
}
response
}
async fn store_response_internal(
@@ -300,6 +581,172 @@ impl OpenAIRouter {
.map_err(|e| format!("Failed to store response: {}", e))
}
/// Fills request-derived fields into a final response object before it is
/// stored.
///
/// Does nothing when `response_json` is not a JSON object. Otherwise:
/// - `previous_response_id` is set from `original_previous_response_id` when
///   the response lacks it, or holds `null` or an empty string;
/// - `instructions` and `metadata` are copied from the request when the
///   response lacks them or holds `null`;
/// - `store` is always overwritten with the request's value;
/// - `model` is copied from the request when the response's value is missing,
///   not a string, or empty;
/// - `user` is copied from the request only when the response holds an
///   explicit `null`.
fn patch_streaming_response_json(
    response_json: &mut Value,
    original_body: &ResponsesRequest,
    original_previous_response_id: Option<&str>,
) {
    let Some(fields) = response_json.as_object_mut() else {
        return;
    };

    if let Some(prev_id) = original_previous_response_id {
        let missing_or_blank = match fields.get("previous_response_id") {
            None | Some(Value::Null) => true,
            Some(Value::String(existing)) => existing.is_empty(),
            Some(_) => false,
        };
        if missing_or_blank {
            fields.insert(
                "previous_response_id".to_string(),
                Value::String(prev_id.to_string()),
            );
        }
    }

    let lacks_instructions = fields.get("instructions").map_or(true, Value::is_null);
    if lacks_instructions {
        if let Some(instructions) = &original_body.instructions {
            fields.insert(
                "instructions".to_string(),
                Value::String(instructions.clone()),
            );
        }
    }

    let lacks_metadata = fields.get("metadata").map_or(true, Value::is_null);
    if lacks_metadata {
        if let Some(metadata) = &original_body.metadata {
            let mut metadata_map: serde_json::Map<String, Value> = serde_json::Map::new();
            for (key, value) in metadata.iter() {
                metadata_map.insert(key.clone(), value.clone());
            }
            fields.insert("metadata".to_string(), Value::Object(metadata_map));
        }
    }

    // The request's `store` flag always wins over the upstream value.
    fields.insert("store".to_string(), Value::Bool(original_body.store));

    let has_model = matches!(fields.get("model"), Some(Value::String(name)) if !name.is_empty());
    if !has_model {
        if let Some(model) = &original_body.model {
            fields.insert("model".to_string(), Value::String(model.clone()));
        }
    }

    // Unlike the fields above, a missing `user` key is left alone; only an
    // explicit null is replaced.
    if matches!(fields.get("user"), Some(Value::Null)) {
        if let Some(user) = &original_body.user {
            fields.insert("user".to_string(), Value::String(user.clone()));
        }
    }
}
/// Rewrites one SSE block so that the embedded `response` object reflects the
/// caller's original request.
///
/// Only `response.created`, `response.in_progress` and `response.completed`
/// events (identified by the payload's `type` field) are candidates. For
/// those, `response.store` is forced to the request's `store` flag and
/// `response.previous_response_id` is filled in from
/// `original_previous_response_id` when it is missing, `null` or empty.
///
/// Returns `Some(block)` with the rewritten block (without the trailing blank
/// line) when something changed. Returns `None` when the block should be
/// forwarded untouched: it is empty, has no `data:` line, is not valid JSON,
/// is another event type, or already has the desired values.
fn rewrite_streaming_block(
    block: &str,
    original_body: &ResponsesRequest,
    original_previous_response_id: Option<&str>,
) -> Option<String> {
    let trimmed = block.trim();
    if trimmed.is_empty() {
        return None;
    }

    // `strip_prefix` removes the field name exactly once. Stripping it
    // repeatedly would turn a malformed `data:data:{...}` line into a payload
    // that parses as valid JSON.
    let data_lines: Vec<&str> = trimmed
        .lines()
        .filter_map(|line| line.strip_prefix("data:"))
        .map(str::trim_start)
        .collect();
    if data_lines.is_empty() {
        return None;
    }

    let payload = data_lines.join("\n");
    let mut parsed: Value = match serde_json::from_str(&payload) {
        Ok(value) => value,
        Err(err) => {
            warn!("Failed to parse streaming JSON payload: {}", err);
            return None;
        }
    };

    let should_patch = matches!(
        parsed
            .get("type")
            .and_then(Value::as_str)
            .unwrap_or_default(),
        "response.created" | "response.in_progress" | "response.completed"
    );
    if !should_patch {
        return None;
    }

    let mut changed = false;
    if let Some(response_obj) = parsed.get_mut("response").and_then(Value::as_object_mut) {
        let desired_store = Value::Bool(original_body.store);
        if response_obj.get("store") != Some(&desired_store) {
            response_obj.insert("store".to_string(), desired_store);
            changed = true;
        }

        if let Some(prev_id) = original_previous_response_id {
            let needs_previous = response_obj
                .get("previous_response_id")
                .map(|v| v.is_null() || v.as_str().map(str::is_empty).unwrap_or(false))
                .unwrap_or(true);
            if needs_previous {
                response_obj.insert(
                    "previous_response_id".to_string(),
                    Value::String(prev_id.to_string()),
                );
                changed = true;
            }
        }
    }

    if !changed {
        return None;
    }

    let new_payload = match serde_json::to_string(&parsed) {
        Ok(json) => json,
        Err(err) => {
            warn!("Failed to serialize modified streaming payload: {}", err);
            return None;
        }
    };

    // Rebuild the block: non-data lines (e.g. `event:`) are kept in place,
    // the first `data:` line is replaced by the re-serialized payload, and
    // any further `data:` lines are dropped because the new payload is a
    // single line. At least one `data:` line exists at this point, so the
    // payload is always written.
    let mut rebuilt_lines: Vec<String> = Vec::new();
    let mut data_written = false;
    for line in trimmed.lines() {
        if line.starts_with("data:") {
            if !data_written {
                rebuilt_lines.push(format!("data: {}", new_payload));
                data_written = true;
            }
        } else {
            rebuilt_lines.push(line.to_string());
        }
    }

    Some(rebuilt_lines.join("\n"))
}
fn extract_primary_output_text(response_json: &Value) -> Option<String> {
if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) {
for item in items {
@@ -595,14 +1042,6 @@ impl super::super::RouterTrait for OpenAIRouter {
"openai_responses_request"
);
if body.stream {
return (
StatusCode::NOT_IMPLEMENTED,
"Streaming responses not yet implemented",
)
.into_response();
}
// Clone the body and override model if needed
let mut request_body = body.clone();
if let Some(model) = model_id {