[router] Support streaming for Openai Router Response api (#10822)
This commit is contained in:
@@ -17,11 +17,14 @@ use axum::{
|
||||
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures_util::StreamExt;
|
||||
use serde_json::{json, to_value, Value};
|
||||
use std::{
|
||||
any::Any,
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
io,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
@@ -51,6 +54,142 @@ impl std::fmt::Debug for OpenAIRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper that parses SSE frames from the OpenAI responses stream and
/// accumulates enough information to persist the final response locally.
///
/// Fed one SSE block at a time via `ingest_block`; consumed at end of
/// stream via `into_final_response`.
struct StreamingResponseAccumulator {
    /// The initial `response.created` payload (if emitted).
    initial_response: Option<Value>,
    /// The final `response.completed` payload (if emitted).
    completed_response: Option<Value>,
    /// Collected output items keyed by the upstream output index, used when
    /// a final response payload is absent and we need to synthesize one.
    output_items: Vec<(usize, Value)>,
    /// Captured error payload (if the upstream stream fails midway).
    encountered_error: Option<Value>,
}
|
||||
|
||||
impl StreamingResponseAccumulator {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
initial_response: None,
|
||||
completed_response: None,
|
||||
output_items: Vec::new(),
|
||||
encountered_error: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed the accumulator with the next SSE chunk.
|
||||
fn ingest_block(&mut self, block: &str) {
|
||||
if block.trim().is_empty() {
|
||||
return;
|
||||
}
|
||||
self.process_block(block);
|
||||
}
|
||||
|
||||
/// Consume the accumulator and produce the best-effort final response value.
|
||||
fn into_final_response(mut self) -> Option<Value> {
|
||||
if self.completed_response.is_some() {
|
||||
return self.completed_response;
|
||||
}
|
||||
|
||||
self.build_fallback_response()
|
||||
}
|
||||
|
||||
fn encountered_error(&self) -> Option<&Value> {
|
||||
self.encountered_error.as_ref()
|
||||
}
|
||||
fn process_block(&mut self, block: &str) {
|
||||
let trimmed = block.trim();
|
||||
if trimmed.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut event_name: Option<String> = None;
|
||||
let mut data_lines: Vec<String> = Vec::new();
|
||||
|
||||
for line in trimmed.lines() {
|
||||
if let Some(rest) = line.strip_prefix("event:") {
|
||||
event_name = Some(rest.trim().to_string());
|
||||
} else if let Some(rest) = line.strip_prefix("data:") {
|
||||
data_lines.push(rest.trim_start().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let data_payload = data_lines.join("\n");
|
||||
if data_payload.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
self.handle_event(event_name.as_deref(), &data_payload);
|
||||
}
|
||||
|
||||
fn handle_event(&mut self, event_name: Option<&str>, data_payload: &str) {
|
||||
let parsed: Value = match serde_json::from_str(data_payload) {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse streaming event JSON: {}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let event_type = event_name
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.get("type")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string())
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
match event_type.as_str() {
|
||||
"response.created" => {
|
||||
if let Some(response) = parsed.get("response") {
|
||||
self.initial_response = Some(response.clone());
|
||||
}
|
||||
}
|
||||
"response.completed" => {
|
||||
if let Some(response) = parsed.get("response") {
|
||||
self.completed_response = Some(response.clone());
|
||||
}
|
||||
}
|
||||
"response.output_item.done" => {
|
||||
if let (Some(index), Some(item)) = (
|
||||
parsed
|
||||
.get("output_index")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|v| v as usize),
|
||||
parsed.get("item"),
|
||||
) {
|
||||
self.output_items.push((index, item.clone()));
|
||||
}
|
||||
}
|
||||
"response.error" => {
|
||||
self.encountered_error = Some(parsed);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_fallback_response(&mut self) -> Option<Value> {
|
||||
let mut response = self.initial_response.clone()?;
|
||||
|
||||
if let Some(obj) = response.as_object_mut() {
|
||||
obj.insert("status".to_string(), Value::String("completed".to_string()));
|
||||
|
||||
self.output_items.sort_by_key(|(index, _)| *index);
|
||||
let outputs: Vec<Value> = self
|
||||
.output_items
|
||||
.iter()
|
||||
.map(|(_, item)| item.clone())
|
||||
.collect();
|
||||
obj.insert("output".to_string(), Value::Array(outputs));
|
||||
}
|
||||
|
||||
Some(response)
|
||||
}
|
||||
}
|
||||
|
||||
impl OpenAIRouter {
|
||||
/// Create a new OpenAI router
|
||||
pub async fn new(
|
||||
@@ -203,17 +342,159 @@ impl OpenAIRouter {
|
||||
|
||||
/// Forward a streaming Responses request upstream and relay the SSE body to
/// the client, optionally persisting the accumulated final response.
///
/// The upstream byte stream is re-framed into SSE blocks so each block can be
/// rewritten (store / previous_response_id patching) before being forwarded.
/// Persistence happens in a detached task and does not block the client.
async fn handle_streaming_response(
    &self,
    url: String,
    headers: Option<&HeaderMap>,
    payload: Value,
    original_body: &ResponsesRequest,
    original_previous_response_id: Option<String>,
) -> Response {
    let mut request_builder = self.client.post(&url).json(&payload);

    if let Some(headers) = headers {
        request_builder = apply_request_headers(headers, request_builder, true);
    }

    // Ask upstream explicitly for an event stream.
    request_builder = request_builder.header("Accept", "text/event-stream");

    let response = match request_builder.send().await {
        Ok(resp) => resp,
        Err(err) => {
            self.circuit_breaker.record_failure();
            return (
                StatusCode::BAD_GATEWAY,
                format!("Failed to forward request to OpenAI: {}", err),
            )
                .into_response();
        }
    };

    let status = response.status();
    // Convert reqwest's status type into axum's; fall back defensively.
    let status_code =
        StatusCode::from_u16(status.as_u16()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

    if !status.is_success() {
        self.circuit_breaker.record_failure();
        let error_body = match response.text().await {
            Ok(body) => body,
            Err(err) => format!("Failed to read upstream error body: {}", err),
        };
        return (status_code, error_body).into_response();
    }

    self.circuit_breaker.record_success();

    let preserved_headers = preserve_response_headers(response.headers());
    let mut upstream_stream = response.bytes_stream();

    // Channel bridging the relay task to the axum response body.
    let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();

    let should_store = original_body.store;
    let storage = self.response_storage.clone();
    let original_request = original_body.clone();
    let previous_response_id = original_previous_response_id.clone();

    tokio::spawn(async move {
        let mut accumulator = StreamingResponseAccumulator::new();
        let mut upstream_failed = false;
        // Flips to false once the client hangs up (send fails).
        let mut receiver_connected = true;
        // Buffer for bytes that do not yet form a complete SSE block.
        let mut pending = String::new();

        while let Some(chunk_result) = upstream_stream.next().await {
            match chunk_result {
                Ok(chunk) => {
                    let chunk_text = match std::str::from_utf8(&chunk) {
                        Ok(text) => Cow::Borrowed(text),
                        Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()),
                    };

                    // Normalize CRLF so block detection only looks for "\n\n".
                    pending.push_str(&chunk_text.replace("\r\n", "\n"));

                    // Drain every complete SSE block currently buffered.
                    while let Some(pos) = pending.find("\n\n") {
                        let raw_block = pending[..pos].to_string();
                        pending.drain(..pos + 2);

                        if raw_block.trim().is_empty() {
                            continue;
                        }

                        // Patch store/previous_response_id into the block if
                        // needed; otherwise forward it unmodified.
                        let block_cow = if let Some(modified) = Self::rewrite_streaming_block(
                            raw_block.as_str(),
                            &original_request,
                            previous_response_id.as_deref(),
                        ) {
                            Cow::Owned(modified)
                        } else {
                            Cow::Borrowed(raw_block.as_str())
                        };

                        if should_store {
                            accumulator.ingest_block(block_cow.as_ref());
                        }

                        if receiver_connected {
                            let chunk_to_send = format!("{}\n\n", block_cow);
                            if tx.send(Ok(Bytes::from(chunk_to_send))).is_err() {
                                receiver_connected = false;
                            }
                        }

                        // No client and nothing to persist: stop early.
                        if !receiver_connected && !should_store {
                            break;
                        }
                    }

                    if !receiver_connected && !should_store {
                        break;
                    }
                }
                Err(err) => {
                    upstream_failed = true;
                    let io_err = io::Error::other(err);
                    // Best-effort: the client may already be gone.
                    let _ = tx.send(Err(io_err));
                    break;
                }
            }
        }

        // Persist only when requested and the upstream stream ended cleanly.
        if should_store && !upstream_failed {
            if !pending.trim().is_empty() {
                // Final block may lack the trailing blank-line separator.
                accumulator.ingest_block(&pending);
            }
            let encountered_error = accumulator.encountered_error().cloned();
            if let Some(mut response_json) = accumulator.into_final_response() {
                Self::patch_streaming_response_json(
                    &mut response_json,
                    &original_request,
                    previous_response_id.as_deref(),
                );

                if let Err(err) =
                    Self::store_response_impl(&storage, &response_json, &original_request).await
                {
                    warn!("Failed to store streaming response: {}", err);
                }
            } else if let Some(error_payload) = encountered_error {
                warn!("Upstream streaming error payload: {}", error_payload);
            } else {
                warn!("Streaming completed without a final response payload");
            }
        }
    });

    let body_stream = UnboundedReceiverStream::new(rx);
    let mut response = Response::new(Body::from_stream(body_stream));
    *response.status_mut() = status_code;

    // Propagate the upstream headers we chose to preserve.
    let headers_mut = response.headers_mut();
    for (name, value) in preserved_headers.iter() {
        headers_mut.insert(name, value.clone());
    }

    if !headers_mut.contains_key(CONTENT_TYPE) {
        headers_mut.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
    }

    response
}
|
||||
|
||||
async fn store_response_internal(
|
||||
@@ -300,6 +581,172 @@ impl OpenAIRouter {
|
||||
.map_err(|e| format!("Failed to store response: {}", e))
|
||||
}
|
||||
|
||||
fn patch_streaming_response_json(
|
||||
response_json: &mut Value,
|
||||
original_body: &ResponsesRequest,
|
||||
original_previous_response_id: Option<&str>,
|
||||
) {
|
||||
if let Some(obj) = response_json.as_object_mut() {
|
||||
if let Some(prev_id) = original_previous_response_id {
|
||||
let should_insert = obj
|
||||
.get("previous_response_id")
|
||||
.map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false))
|
||||
.unwrap_or(true);
|
||||
if should_insert {
|
||||
obj.insert(
|
||||
"previous_response_id".to_string(),
|
||||
Value::String(prev_id.to_string()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if !obj.contains_key("instructions")
|
||||
|| obj
|
||||
.get("instructions")
|
||||
.map(|v| v.is_null())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Some(instructions) = &original_body.instructions {
|
||||
obj.insert(
|
||||
"instructions".to_string(),
|
||||
Value::String(instructions.clone()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if !obj.contains_key("metadata")
|
||||
|| obj.get("metadata").map(|v| v.is_null()).unwrap_or(false)
|
||||
{
|
||||
if let Some(metadata) = &original_body.metadata {
|
||||
let metadata_map: serde_json::Map<String, Value> = metadata
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect();
|
||||
obj.insert("metadata".to_string(), Value::Object(metadata_map));
|
||||
}
|
||||
}
|
||||
|
||||
obj.insert("store".to_string(), Value::Bool(original_body.store));
|
||||
|
||||
if obj
|
||||
.get("model")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.is_empty())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
if let Some(model) = &original_body.model {
|
||||
obj.insert("model".to_string(), Value::String(model.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if obj.get("user").map(|v| v.is_null()).unwrap_or(false) {
|
||||
if let Some(user) = &original_body.user {
|
||||
obj.insert("user".to_string(), Value::String(user.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rewrite a single SSE block in-flight so the client sees the request's
/// `store` flag and `previous_response_id` on response snapshot events.
///
/// Returns `Some(rewritten_block)` only when the block carries a
/// `response.created` / `response.in_progress` / `response.completed` payload
/// AND at least one field actually changed; returns `None` to signal "forward
/// the original block untouched".
fn rewrite_streaming_block(
    block: &str,
    original_body: &ResponsesRequest,
    original_previous_response_id: Option<&str>,
) -> Option<String> {
    let trimmed = block.trim();
    if trimmed.is_empty() {
        return None;
    }

    // Collect every `data:` line; per the SSE spec they form one payload.
    let mut data_lines: Vec<String> = Vec::new();

    for line in trimmed.lines() {
        if line.starts_with("data:") {
            data_lines.push(line.trim_start_matches("data:").trim_start().to_string());
        }
    }

    if data_lines.is_empty() {
        return None;
    }

    let payload = data_lines.join("\n");
    let mut parsed: Value = match serde_json::from_str(&payload) {
        Ok(value) => value,
        Err(err) => {
            warn!("Failed to parse streaming JSON payload: {}", err);
            return None;
        }
    };

    let event_type = parsed
        .get("type")
        .and_then(|v| v.as_str())
        .unwrap_or_default();

    // Only events embedding a full response snapshot are patched.
    let should_patch = matches!(
        event_type,
        "response.created" | "response.in_progress" | "response.completed"
    );

    if !should_patch {
        return None;
    }

    // Track whether we actually modified anything, so unchanged blocks
    // are forwarded verbatim (avoids re-serialization differences).
    let mut changed = false;
    if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) {
        let desired_store = Value::Bool(original_body.store);
        if response_obj.get("store") != Some(&desired_store) {
            response_obj.insert("store".to_string(), desired_store);
            changed = true;
        }

        if let Some(prev_id) = original_previous_response_id {
            // Missing key, null, or "" all count as "needs the original id".
            let needs_previous = response_obj
                .get("previous_response_id")
                .map(|v| v.is_null() || v.as_str().map(|s| s.is_empty()).unwrap_or(false))
                .unwrap_or(true);

            if needs_previous {
                response_obj.insert(
                    "previous_response_id".to_string(),
                    Value::String(prev_id.to_string()),
                );
                changed = true;
            }
        }
    }

    if !changed {
        return None;
    }

    let new_payload = match serde_json::to_string(&parsed) {
        Ok(json) => json,
        Err(err) => {
            warn!("Failed to serialize modified streaming payload: {}", err);
            return None;
        }
    };

    // Rebuild the block: keep non-data lines (event:, id:, comments) as-is
    // and collapse all original data lines into one rewritten data line.
    let mut rebuilt_lines = Vec::new();
    let mut data_written = false;
    for line in trimmed.lines() {
        if line.starts_with("data:") {
            if !data_written {
                rebuilt_lines.push(format!("data: {}", new_payload));
                data_written = true;
            }
        } else {
            rebuilt_lines.push(line.to_string());
        }
    }

    if !data_written {
        rebuilt_lines.push(format!("data: {}", new_payload));
    }

    Some(rebuilt_lines.join("\n"))
}
|
||||
fn extract_primary_output_text(response_json: &Value) -> Option<String> {
|
||||
if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) {
|
||||
for item in items {
|
||||
@@ -595,14 +1042,6 @@ impl super::super::RouterTrait for OpenAIRouter {
|
||||
"openai_responses_request"
|
||||
);
|
||||
|
||||
if body.stream {
|
||||
return (
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Streaming responses not yet implemented",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
// Clone the body and override model if needed
|
||||
let mut request_body = body.clone();
|
||||
if let Some(model) = model_id {
|
||||
|
||||
@@ -4,24 +4,27 @@ use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{Method, StatusCode},
|
||||
response::Response,
|
||||
routing::post,
|
||||
Json, Router,
|
||||
};
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::{
|
||||
config::{RouterConfig, RoutingMode},
|
||||
data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage},
|
||||
data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage, StoredResponse},
|
||||
protocols::spec::{
|
||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
|
||||
ResponsesGetParams, ResponsesRequest, UserMessageContent,
|
||||
},
|
||||
routers::{openai_router::OpenAIRouter, RouterTrait},
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tower::ServiceExt;
|
||||
|
||||
mod common;
|
||||
@@ -309,6 +312,255 @@ async fn test_openai_router_responses_with_mock() {
|
||||
server.abort();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_openai_router_responses_streaming_with_mock() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
let sse_handler = post(|Json(_request): Json<serde_json::Value>| async move {
|
||||
let response_id = "resp_stream_123";
|
||||
let message_id = "msg_stream_123";
|
||||
let final_text = "Once upon a streamed unicorn adventure.";
|
||||
|
||||
let events = vec![
|
||||
(
|
||||
"response.created",
|
||||
json!({
|
||||
"type": "response.created",
|
||||
"sequence_number": 0,
|
||||
"response": {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": 1_700_000_500,
|
||||
"status": "in_progress",
|
||||
"model": "",
|
||||
"output": [],
|
||||
"parallel_tool_calls": true,
|
||||
"previous_response_id": null,
|
||||
"reasoning": null,
|
||||
"store": false,
|
||||
"temperature": 1.0,
|
||||
"text": {"format": {"type": "text"}},
|
||||
"tool_choice": "auto",
|
||||
"tools": [],
|
||||
"top_p": 1.0,
|
||||
"truncation": "disabled",
|
||||
"usage": null,
|
||||
"metadata": null
|
||||
}
|
||||
}),
|
||||
),
|
||||
(
|
||||
"response.output_item.added",
|
||||
json!({
|
||||
"type": "response.output_item.added",
|
||||
"sequence_number": 1,
|
||||
"output_index": 0,
|
||||
"item": {
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "in_progress",
|
||||
"content": []
|
||||
}
|
||||
}),
|
||||
),
|
||||
(
|
||||
"response.output_text.delta",
|
||||
json!({
|
||||
"type": "response.output_text.delta",
|
||||
"sequence_number": 2,
|
||||
"item_id": message_id,
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"delta": "Once upon a streamed unicorn adventure.",
|
||||
"logprobs": []
|
||||
}),
|
||||
),
|
||||
(
|
||||
"response.output_text.done",
|
||||
json!({
|
||||
"type": "response.output_text.done",
|
||||
"sequence_number": 3,
|
||||
"item_id": message_id,
|
||||
"output_index": 0,
|
||||
"content_index": 0,
|
||||
"text": final_text,
|
||||
"logprobs": []
|
||||
}),
|
||||
),
|
||||
(
|
||||
"response.output_item.done",
|
||||
json!({
|
||||
"type": "response.output_item.done",
|
||||
"sequence_number": 4,
|
||||
"output_index": 0,
|
||||
"item": {
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": final_text,
|
||||
"annotations": [],
|
||||
"logprobs": []
|
||||
}]
|
||||
}
|
||||
}),
|
||||
),
|
||||
(
|
||||
"response.completed",
|
||||
json!({
|
||||
"type": "response.completed",
|
||||
"sequence_number": 5,
|
||||
"response": {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": 1_700_000_500,
|
||||
"status": "completed",
|
||||
"model": "",
|
||||
"output": [{
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": [{
|
||||
"type": "output_text",
|
||||
"text": final_text,
|
||||
"annotations": [],
|
||||
"logprobs": []
|
||||
}]
|
||||
}],
|
||||
"parallel_tool_calls": true,
|
||||
"previous_response_id": null,
|
||||
"reasoning": null,
|
||||
"store": false,
|
||||
"temperature": 1.0,
|
||||
"text": {"format": {"type": "text"}},
|
||||
"tool_choice": "auto",
|
||||
"tools": [],
|
||||
"top_p": 1.0,
|
||||
"truncation": "disabled",
|
||||
"usage": {
|
||||
"input_tokens": 10,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens": 20,
|
||||
"output_tokens_details": {"reasoning_tokens": 5},
|
||||
"total_tokens": 30
|
||||
},
|
||||
"metadata": null,
|
||||
"instructions": null,
|
||||
"user": null
|
||||
}
|
||||
}),
|
||||
),
|
||||
];
|
||||
|
||||
let sse_payload = events
|
||||
.into_iter()
|
||||
.map(|(event, data)| format!("event: {}\ndata: {}\n\n", event, data))
|
||||
.collect::<String>();
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("content-type", "text/event-stream")
|
||||
.body(Body::from(sse_payload))
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
let app = Router::new().route("/v1/responses", sse_handler);
|
||||
|
||||
let server = tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
let base_url = format!("http://{}", addr);
|
||||
let storage = Arc::new(MemoryResponseStorage::new());
|
||||
|
||||
// Seed a previous response so previous_response_id logic has data to pull from.
|
||||
let mut previous = StoredResponse::new(
|
||||
"Earlier bedtime question".to_string(),
|
||||
"Earlier answer".to_string(),
|
||||
None,
|
||||
);
|
||||
previous.id = ResponseId::from_string("resp_prev_chain".to_string());
|
||||
storage.store_response(previous).await.unwrap();
|
||||
|
||||
let router = OpenAIRouter::new(base_url, None, storage.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("topic".to_string(), json!("unicorns"));
|
||||
|
||||
let request = ResponsesRequest {
|
||||
model: Some("gpt-5-nano".to_string()),
|
||||
input: ResponseInput::Text("Tell me a bedtime story.".to_string()),
|
||||
instructions: Some("Be kind".to_string()),
|
||||
metadata: Some(metadata),
|
||||
previous_response_id: Some("resp_prev_chain".to_string()),
|
||||
store: true,
|
||||
stream: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response = router.route_responses(None, &request, None).await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let headers = response.headers();
|
||||
let ct = headers
|
||||
.get("content-type")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
.to_ascii_lowercase();
|
||||
assert!(ct.contains("text/event-stream"));
|
||||
|
||||
let response_body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let body_text = String::from_utf8(response_body.to_vec()).unwrap();
|
||||
assert!(body_text.contains("response.completed"));
|
||||
assert!(body_text.contains("Once upon a streamed unicorn adventure."));
|
||||
|
||||
// Wait for the storage task to persist the streaming response.
|
||||
let target_id = ResponseId::from_string("resp_stream_123".to_string());
|
||||
let stored = loop {
|
||||
if let Some(resp) = storage.get_response(&target_id).await.unwrap() {
|
||||
break resp;
|
||||
}
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
};
|
||||
|
||||
assert_eq!(stored.input, "Tell me a bedtime story.");
|
||||
assert_eq!(stored.output, "Once upon a streamed unicorn adventure.");
|
||||
assert_eq!(
|
||||
stored
|
||||
.previous_response_id
|
||||
.as_ref()
|
||||
.expect("previous_response_id missing")
|
||||
.0,
|
||||
"resp_prev_chain"
|
||||
);
|
||||
assert_eq!(stored.metadata.get("topic"), Some(&json!("unicorns")));
|
||||
assert_eq!(stored.instructions.as_deref(), Some("Be kind"));
|
||||
assert_eq!(stored.model.as_deref(), Some("gpt-5-nano"));
|
||||
assert_eq!(stored.user, None);
|
||||
assert_eq!(stored.raw_response["store"], json!(true));
|
||||
assert_eq!(
|
||||
stored.raw_response["previous_response_id"].as_str(),
|
||||
Some("resp_prev_chain")
|
||||
);
|
||||
assert_eq!(stored.raw_response["metadata"]["topic"], json!("unicorns"));
|
||||
assert_eq!(
|
||||
stored.raw_response["instructions"].as_str(),
|
||||
Some("Be kind")
|
||||
);
|
||||
|
||||
server.abort();
|
||||
}
|
||||
|
||||
/// Test router factory with OpenAI routing mode
|
||||
#[tokio::test]
|
||||
async fn test_router_factory_openai_mode() {
|
||||
|
||||
Reference in New Issue
Block a user