[router] responses api POST and GET with local storage (#10581)
Co-authored-by: key4ng <rukeyang@gmail.com>
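This change wires local response storage into the OpenAI router: a successful POST /v1/responses with `store: true` is persisted through the `ResponseStorage` trait, and GET /v1/responses/{id} (plus follow-up requests with `previous_response_id`) is answered from that store. A minimal sketch of the round trip against the in-memory backend touched below — the `main` scaffolding is illustrative, and `MemoryResponseStorage`'s exact export path is an assumption; the other types are the ones this commit introduces or extends:

    // Paths assumed from this diff's `use crate::data_connector::...` imports.
    use crate::data_connector::{
        MemoryResponseStorage, ResponseId, SharedResponseStorage, StoredResponse,
    };
    use std::sync::Arc;

    #[tokio::main]
    async fn main() {
        // SharedResponseStorage = Arc<dyn ResponseStorage>, as defined in this diff.
        let storage: SharedResponseStorage = Arc::new(MemoryResponseStorage::new());

        // What the router does after a successful upstream POST /v1/responses
        // when the client sent store=true:
        let mut response = StoredResponse::new("Hello".to_string(), "Hi there!".to_string(), None);
        response.id = ResponseId::from_string("resp_abc123".to_string());
        let id = storage.store_response(response).await.unwrap();

        // What GET /v1/responses/{id} is answered from:
        let stored = storage.get_response(&id).await.unwrap();
        assert!(stored.is_some());
    }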
@@ -74,13 +74,16 @@ impl ResponseStorage for MemoryResponseStorage {
 
         // Store the response
         store.responses.insert(response_id.clone(), response);
+        tracing::info!("memory_store_size" = store.responses.len());
 
         Ok(response_id)
     }
 
     async fn get_response(&self, response_id: &ResponseId) -> Result<Option<StoredResponse>> {
         let store = self.store.read();
-        Ok(store.responses.get(response_id).cloned())
+        let result = store.responses.get(response_id).cloned();
+        tracing::info!("memory_get_response" = %response_id.0, found = result.is_some());
+        Ok(result)
     }
 
     async fn delete_response(&self, response_id: &ResponseId) -> Result<()> {
@@ -200,6 +203,20 @@ pub struct MemoryStoreStats {
 mod tests {
     use super::*;
 
+    #[tokio::test]
+    async fn test_store_with_custom_id() {
+        let store = MemoryResponseStorage::new();
+        let mut response = StoredResponse::new("Input".to_string(), "Output".to_string(), None);
+        response.id = ResponseId::from_string("resp_custom".to_string());
+        store.store_response(response.clone()).await.unwrap();
+        let retrieved = store
+            .get_response(&ResponseId::from_string("resp_custom".to_string()))
+            .await
+            .unwrap();
+        assert!(retrieved.is_some());
+        assert_eq!(retrieved.unwrap().output, "Output");
+    }
+
     #[tokio::test]
     async fn test_memory_store_basic() {
         let store = MemoryResponseStorage::new();

@@ -1,5 +1,6 @@
 use async_trait::async_trait;
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -55,6 +56,10 @@ pub struct StoredResponse {
 
     /// Model used for generation
     pub model: Option<String>,
+
+    /// Raw OpenAI response payload
+    #[serde(default)]
+    pub raw_response: Value,
 }
 
 impl StoredResponse {
@@ -70,6 +75,7 @@ impl StoredResponse {
             created_at: chrono::Utc::now(),
             user: None,
             model: None,
+            raw_response: Value::Null,
         }
     }
 }
@@ -175,3 +181,9 @@ pub trait ResponseStorage: Send + Sync {
 
 /// Type alias for shared storage
 pub type SharedResponseStorage = Arc<dyn ResponseStorage>;
+
+impl Default for StoredResponse {
+    fn default() -> Self {
+        Self::new(String::new(), String::new(), None)
+    }
+}
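Because the new `raw_response` field carries `#[serde(default)]`, records persisted before this change still deserialize; the missing field simply comes back as `Value::Null`. A test-style sketch of that backward compatibility, assuming `StoredResponse` derives `Serialize`/`Deserialize` as its attributes suggest:

    #[test]
    fn raw_response_defaults_to_null_for_old_records() {
        // Simulate a record written before raw_response existed by dropping the key.
        let mut v = serde_json::to_value(StoredResponse::default()).unwrap();
        v.as_object_mut().unwrap().remove("raw_response");

        // #[serde(default)] fills the gap with Value::default(), i.e. Value::Null.
        let back: StoredResponse = serde_json::from_value(v).unwrap();
        assert!(back.raw_response.is_null());
    }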

@@ -1,5 +1,5 @@
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{to_value, Map, Number, Value};
 use std::collections::HashMap;
 
 // # Protocol Specifications
@@ -350,7 +350,7 @@ pub struct ChatCompletionRequest {
 
     /// Session parameters for continual prompting
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub session_params: Option<HashMap<String, serde_json::Value>>,
+    pub session_params: Option<HashMap<String, Value>>,
 
     /// Separate reasoning content from final answer (O1-style models)
     #[serde(default = "default_true")]
@@ -362,7 +362,7 @@ pub struct ChatCompletionRequest {
 
     /// Chat template kwargs
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub chat_template_kwargs: Option<HashMap<String, serde_json::Value>>,
+    pub chat_template_kwargs: Option<HashMap<String, Value>>,
 
     /// Return model hidden states
     #[serde(default)]
@@ -447,7 +447,7 @@ pub struct ChatChoice {
     pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
     /// Information about which stop condition was matched
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub matched_stop: Option<serde_json::Value>, // Can be string or integer
+    pub matched_stop: Option<Value>, // Can be string or integer
     /// Hidden states from the model (SGLang extension)
     #[serde(skip_serializing_if = "Option::is_none")]
     pub hidden_states: Option<Vec<f32>>,
@@ -606,7 +606,7 @@ pub struct CompletionRequest {
 
     /// Session parameters for continual prompting
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub session_params: Option<HashMap<String, serde_json::Value>>,
+    pub session_params: Option<HashMap<String, Value>>,
 
     /// Return model hidden states
     #[serde(default)]
@@ -618,7 +618,7 @@ pub struct CompletionRequest {
 
     /// Additional fields including bootstrap info for PD routing
     #[serde(flatten)]
-    pub other: serde_json::Map<String, serde_json::Value>,
+    pub other: Map<String, Value>,
 }
 
 impl GenerationRequest for CompletionRequest {
@@ -662,7 +662,7 @@ pub struct CompletionChoice {
     pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
     /// Information about which stop condition was matched
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub matched_stop: Option<serde_json::Value>, // Can be string or integer
+    pub matched_stop: Option<Value>, // Can be string or integer
     /// Hidden states from the model (SGLang extension)
     #[serde(skip_serializing_if = "Option::is_none")]
     pub hidden_states: Option<Vec<f32>>,
@@ -776,6 +776,10 @@ pub enum ResponseContentPart {
         #[serde(skip_serializing_if = "Option::is_none")]
         logprobs: Option<ChatLogProbs>,
     },
+    #[serde(rename = "input_text")]
+    InputText { text: String },
+    #[serde(other)]
+    Unknown,
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -864,6 +868,29 @@ pub enum ResponseStatus {
     Cancelled,
 }
 
+// ============= Reasoning Info =============
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ReasoningInfo {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub effort: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub summary: Option<String>,
+}
+
+// ============= Text Format =============
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ResponseTextFormat {
+    pub format: TextFormatType,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct TextFormatType {
+    #[serde(rename = "type")]
+    pub format_type: String,
+}
+
 // ============= Include Fields =============
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
@@ -915,6 +942,13 @@ pub struct ResponseUsage {
     pub output_tokens_details: Option<OutputTokensDetails>,
 }
 
+#[derive(Debug, Clone, Deserialize, Serialize)]
+#[serde(untagged)]
+pub enum ResponsesUsage {
+    Classic(UsageInfo),
+    Modern(ResponseUsage),
+}
+
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct InputTokensDetails {
     pub cached_tokens: u32,
@@ -970,6 +1004,34 @@ impl ResponseUsage {
     }
 }
 
+#[derive(Debug, Clone, Default, Deserialize, Serialize)]
+pub struct ResponsesGetParams {
+    #[serde(default)]
+    pub include: Vec<String>,
+    #[serde(default)]
+    pub include_obfuscation: Option<bool>,
+    #[serde(default)]
+    pub starting_after: Option<i64>,
+    #[serde(default)]
+    pub stream: Option<bool>,
+}
+
+impl ResponsesUsage {
+    pub fn to_response_usage(&self) -> ResponseUsage {
+        match self {
+            ResponsesUsage::Classic(usage) => usage.to_response_usage(),
+            ResponsesUsage::Modern(usage) => usage.clone(),
+        }
+    }
+
+    pub fn to_usage_info(&self) -> UsageInfo {
+        match self {
+            ResponsesUsage::Classic(usage) => usage.clone(),
+            ResponsesUsage::Modern(usage) => usage.to_usage_info(),
+        }
+    }
+}
+
 fn generate_request_id() -> String {
     format!("resp_{}", uuid::Uuid::new_v4().simple())
 }
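`ResponsesUsage` is `#[serde(untagged)]`, so serde tries `Classic` first and falls back to `Modern`; which variant a payload lands in is decided purely by its field names. A sketch, assuming `UsageInfo` requires the classic `prompt_tokens`/`completion_tokens` fields and `ResponseUsage` the `input_tokens`/`output_tokens` ones used elsewhere in this file:

    #[test]
    fn responses_usage_accepts_both_shapes() {
        // Classic Chat Completions-style usage matches the first variant...
        let classic: ResponsesUsage = serde_json::from_str(
            r#"{"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8}"#,
        )
        .unwrap();
        assert!(matches!(classic, ResponsesUsage::Classic(_)));

        // ...while Responses API-style usage falls through to Modern.
        let modern: ResponsesUsage = serde_json::from_str(
            r#"{"input_tokens": 3, "output_tokens": 5, "total_tokens": 8}"#,
        )
        .unwrap();
        assert!(matches!(modern, ResponsesUsage::Modern(_)));
    }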
@@ -1002,7 +1064,7 @@ pub struct ResponsesRequest {
 
     /// Additional metadata
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub metadata: Option<HashMap<String, serde_json::Value>>,
+    pub metadata: Option<HashMap<String, Value>>,
 
     /// Model to use (optional to match vLLM)
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1109,6 +1171,42 @@ fn default_repetition_penalty() -> f32 {
     1.0
 }
 
+impl Default for ResponsesRequest {
+    fn default() -> Self {
+        Self {
+            background: false,
+            include: None,
+            input: ResponseInput::Text(String::new()),
+            instructions: None,
+            max_output_tokens: None,
+            max_tool_calls: None,
+            metadata: None,
+            model: None,
+            parallel_tool_calls: true,
+            previous_response_id: None,
+            reasoning: None,
+            service_tier: ServiceTier::default(),
+            store: true,
+            stream: false,
+            temperature: None,
+            tool_choice: ToolChoice::default(),
+            tools: Vec::new(),
+            top_logprobs: 0,
+            top_p: None,
+            truncation: Truncation::default(),
+            user: None,
+            request_id: generate_request_id(),
+            priority: 0,
+            frequency_penalty: 0.0,
+            presence_penalty: 0.0,
+            stop: None,
+            top_k: default_top_k(),
+            min_p: 0.0,
+            repetition_penalty: default_repetition_penalty(),
+        }
+    }
+}
+
 impl ResponsesRequest {
     /// Default sampling parameters
     const DEFAULT_TEMPERATURE: f32 = 0.7;
@@ -1118,8 +1216,8 @@ impl ResponsesRequest {
     pub fn to_sampling_params(
         &self,
         default_max_tokens: u32,
-        default_params: Option<HashMap<String, serde_json::Value>>,
-    ) -> HashMap<String, serde_json::Value> {
+        default_params: Option<HashMap<String, Value>>,
+    ) -> HashMap<String, Value> {
         let mut params = HashMap::new();
 
         // Use max_output_tokens if available
@@ -1154,47 +1252,38 @@ impl ResponsesRequest {
 
         params.insert(
             "max_new_tokens".to_string(),
-            serde_json::Value::Number(serde_json::Number::from(max_tokens)),
+            Value::Number(Number::from(max_tokens)),
         );
         params.insert(
             "temperature".to_string(),
-            serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()),
+            Value::Number(Number::from_f64(temperature as f64).unwrap()),
        );
         params.insert(
             "top_p".to_string(),
-            serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()),
+            Value::Number(Number::from_f64(top_p as f64).unwrap()),
         );
         params.insert(
             "frequency_penalty".to_string(),
-            serde_json::Value::Number(
-                serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(),
-            ),
+            Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()),
         );
         params.insert(
             "presence_penalty".to_string(),
-            serde_json::Value::Number(
-                serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(),
-            ),
-        );
-        params.insert(
-            "top_k".to_string(),
-            serde_json::Value::Number(serde_json::Number::from(self.top_k)),
+            Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()),
         );
+        params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
         params.insert(
             "min_p".to_string(),
-            serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()),
+            Value::Number(Number::from_f64(self.min_p as f64).unwrap()),
        );
         params.insert(
             "repetition_penalty".to_string(),
-            serde_json::Value::Number(
-                serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(),
-            ),
+            Value::Number(Number::from_f64(self.repetition_penalty as f64).unwrap()),
         );
 
         if let Some(ref stop) = self.stop {
-            match serde_json::to_value(stop) {
+            match to_value(stop) {
                 Ok(value) => params.insert("stop".to_string(), value),
-                Err(_) => params.insert("stop".to_string(), serde_json::Value::Null),
+                Err(_) => params.insert("stop".to_string(), Value::Null),
             };
         }
 
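With the `Default` impl above, building a request and lowering it to engine sampling parameters is short; a sketch, where `4096` is an illustrative router-side fallback and `0.5` is chosen because it survives the `f32`-to-`f64` cast exactly:

    #[test]
    fn sampling_params_from_minimal_request() {
        let request = ResponsesRequest {
            max_output_tokens: Some(256),
            temperature: Some(0.5),
            ..Default::default()
        };

        // No extra default params supplied; 4096 is the fallback max-token budget.
        let params = request.to_sampling_params(4096, None);
        assert_eq!(params["max_new_tokens"], serde_json::json!(256));
        assert_eq!(params["temperature"], serde_json::json!(0.5));
    }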
@@ -1227,8 +1316,10 @@ impl GenerationRequest for ResponsesRequest {
             ResponseInputOutputItem::Message { content, .. } => {
                 let texts: Vec<String> = content
                     .iter()
-                    .map(|part| match part {
-                        ResponseContentPart::OutputText { text, .. } => text.clone(),
+                    .filter_map(|part| match part {
+                        ResponseContentPart::OutputText { text, .. } => Some(text.clone()),
+                        ResponseContentPart::InputText { text } => Some(text.clone()),
+                        ResponseContentPart::Unknown => None,
                     })
                     .collect();
                 if texts.is_empty() {
@@ -1285,6 +1376,25 @@ pub struct ResponsesResponse {
     #[serde(default = "current_timestamp")]
     pub created_at: i64,
 
+    /// Response status
+    pub status: ResponseStatus,
+
+    /// Error information if status is failed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<Value>,
+
+    /// Incomplete details if response was truncated
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub incomplete_details: Option<Value>,
+
+    /// System instructions used
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub instructions: Option<String>,
+
+    /// Max output tokens setting
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_output_tokens: Option<u32>,
+
     /// Model name
     pub model: String,
 
@@ -1292,17 +1402,30 @@ pub struct ResponsesResponse {
     #[serde(default)]
     pub output: Vec<ResponseOutputItem>,
 
-    /// Response status
-    pub status: ResponseStatus,
-
-    /// Usage information
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub usage: Option<UsageInfo>,
-
     /// Whether parallel tool calls are enabled
     #[serde(default = "default_true")]
     pub parallel_tool_calls: bool,
 
+    /// Previous response ID if this is a continuation
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub previous_response_id: Option<String>,
+
+    /// Reasoning information
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning: Option<ReasoningInfo>,
+
+    /// Whether the response is stored
+    #[serde(default = "default_true")]
+    pub store: bool,
+
+    /// Temperature setting used
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+
+    /// Text format settings
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub text: Option<ResponseTextFormat>,
+
     /// Tool choice setting
     #[serde(default = "default_tool_choice")]
     pub tool_choice: String,
@@ -1310,6 +1433,26 @@ pub struct ResponsesResponse {
     /// Available tools
     #[serde(default)]
     pub tools: Vec<ResponseTool>,
+
+    /// Top-p setting used
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+
+    /// Truncation strategy used
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub truncation: Option<String>,
+
+    /// Usage information
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub usage: Option<ResponsesUsage>,
+
+    /// User identifier
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub user: Option<String>,
+
+    /// Additional metadata
+    #[serde(default)]
+    pub metadata: HashMap<String, Value>,
 }
 
 fn default_object_type() -> String {
@@ -1325,7 +1468,7 @@ impl ResponsesResponse {
     #[allow(clippy::too_many_arguments)]
     pub fn from_request(
         request: &ResponsesRequest,
-        _sampling_params: &HashMap<String, serde_json::Value>,
+        _sampling_params: &HashMap<String, Value>,
         model_name: String,
         created_time: i64,
         output: Vec<ResponseOutputItem>,
@@ -1336,11 +1479,26 @@ impl ResponsesResponse {
             id: request.request_id.clone(),
             object: "response".to_string(),
             created_at: created_time,
+            status,
+            error: None,
+            incomplete_details: None,
+            instructions: request.instructions.clone(),
+            max_output_tokens: request.max_output_tokens,
             model: model_name,
             output,
-            status,
-            usage,
             parallel_tool_calls: request.parallel_tool_calls,
+            previous_response_id: request.previous_response_id.clone(),
+            reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo {
+                effort: r.effort.as_ref().map(|e| format!("{:?}", e)),
+                summary: None,
+            }),
+            store: request.store,
+            temperature: request.temperature,
+            text: Some(ResponseTextFormat {
+                format: TextFormatType {
+                    format_type: "text".to_string(),
+                },
+            }),
             tool_choice: match &request.tool_choice {
                 ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
                 ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
@@ -1348,6 +1506,14 @@ impl ResponsesResponse {
                 ToolChoice::Function { .. } => "function".to_string(),
             },
             tools: request.tools.clone(),
+            top_p: request.top_p,
+            truncation: match &request.truncation {
+                Truncation::Auto => Some("auto".to_string()),
+                Truncation::Disabled => Some("disabled".to_string()),
+            },
+            usage: usage.map(ResponsesUsage::Classic),
+            user: request.user.clone(),
+            metadata: request.metadata.clone().unwrap_or_default(),
         }
     }
 
@@ -1357,13 +1523,26 @@ impl ResponsesResponse {
             id: request_id,
             object: "response".to_string(),
             created_at: current_timestamp(),
+            status,
+            error: None,
+            incomplete_details: None,
+            instructions: None,
+            max_output_tokens: None,
             model,
             output: Vec::new(),
-            status,
-            usage: None,
             parallel_tool_calls: true,
+            previous_response_id: None,
+            reasoning: None,
+            store: true,
+            temperature: None,
+            text: None,
             tool_choice: "auto".to_string(),
             tools: Vec::new(),
+            top_p: None,
+            truncation: None,
+            usage: None,
+            user: None,
+            metadata: HashMap::new(),
         }
     }
 
@@ -1374,7 +1553,7 @@ impl ResponsesResponse {
 
     /// Set the usage information
     pub fn set_usage(&mut self, usage: UsageInfo) {
-        self.usage = Some(usage);
+        self.usage = Some(ResponsesUsage::Classic(usage));
     }
 
     /// Update the status
@@ -1413,12 +1592,12 @@ impl ResponsesResponse {
     }
 
     /// Get the response as a JSON value with usage in response format
-    pub fn to_response_format(&self) -> serde_json::Value {
-        let mut response = serde_json::to_value(self).unwrap_or(serde_json::Value::Null);
+    pub fn to_response_format(&self) -> Value {
+        let mut response = to_value(self).unwrap_or(Value::Null);
 
         // Convert usage to response format if present
         if let Some(usage) = &self.usage {
-            if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) {
+            if let Ok(usage_value) = to_value(usage.to_response_usage()) {
                 response["usage"] = usage_value;
             }
         }
@@ -1641,8 +1820,13 @@ pub struct LogProbs {
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
-pub struct ChatLogProbs {
-    pub content: Option<Vec<ChatLogProbsContent>>,
+#[serde(untagged)]
+pub enum ChatLogProbs {
+    Detailed {
+        #[serde(skip_serializing_if = "Option::is_none")]
+        content: Option<Vec<ChatLogProbsContent>>,
+    },
+    Raw(Value),
 }
 
 #[derive(Debug, Clone, Deserialize, Serialize)]
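Turning `ChatLogProbs` into an untagged enum keeps the typed `content` path for well-formed payloads while tolerating any other shape a backend emits: serde tries `Detailed` first, then falls back to `Raw`. A sketch:

    #[test]
    fn chat_logprobs_accepts_both_shapes() {
        // The structured shape lands in the typed variant...
        let detailed: ChatLogProbs = serde_json::from_str(r#"{"content": null}"#).unwrap();
        assert!(matches!(detailed, ChatLogProbs::Detailed { .. }));

        // ...anything else is preserved verbatim.
        let raw: ChatLogProbs = serde_json::from_str("[1, 2, 3]").unwrap();
        assert!(matches!(raw, ChatLogProbs::Raw(_)));
    }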
@@ -1798,7 +1982,7 @@ pub struct GenerateRequest {
 
     /// Session parameters for continual prompting
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub session_params: Option<HashMap<String, serde_json::Value>>,
+    pub session_params: Option<HashMap<String, Value>>,
 
     /// Return model hidden states
     #[serde(default)]
@@ -2065,7 +2249,7 @@ pub struct EmbeddingRequest {
     pub model: String,
 
     /// Input can be a string, array of strings, tokens, or batch inputs
-    pub input: serde_json::Value,
+    pub input: Value,
 
     /// Optional encoding format (e.g., "float", "base64")
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -2097,8 +2281,8 @@ impl GenerationRequest for EmbeddingRequest {
     fn extract_text_for_routing(&self) -> String {
         // Best effort: extract text content for routing decisions
         match &self.input {
-            serde_json::Value::String(s) => s.clone(),
-            serde_json::Value::Array(arr) => arr
+            Value::String(s) => s.clone(),
+            Value::Array(arr) => arr
                 .iter()
                 .filter_map(|v| v.as_str())
                 .collect::<Vec<_>>()
@@ -2173,7 +2357,7 @@ pub enum LoRAPath {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use serde_json;
+    use serde_json::{from_str, json, to_string};
 
     // ==================================================================
     // = RERANK REQUEST TESTS =
@@ -2191,8 +2375,8 @@ mod tests {
             user: Some("user-456".to_string()),
         };
 
-        let serialized = serde_json::to_string(&request).unwrap();
-        let deserialized: RerankRequest = serde_json::from_str(&serialized).unwrap();
+        let serialized = to_string(&request).unwrap();
+        let deserialized: RerankRequest = from_str(&serialized).unwrap();
 
         assert_eq!(deserialized.query, request.query);
         assert_eq!(deserialized.documents, request.documents);
@@ -2210,7 +2394,7 @@ mod tests {
             "documents": ["doc1", "doc2"]
         }"#;
 
-        let request: RerankRequest = serde_json::from_str(json).unwrap();
+        let request: RerankRequest = from_str(json).unwrap();
 
         assert_eq!(request.query, "test query");
         assert_eq!(request.documents, vec!["doc1", "doc2"]);
@@ -2402,8 +2586,8 @@ mod tests {
             Some(StringOrArray::String("req-123".to_string())),
         );
 
-        let serialized = serde_json::to_string(&response).unwrap();
-        let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap();
+        let serialized = to_string(&response).unwrap();
+        let deserialized: RerankResponse = from_str(&serialized).unwrap();
 
         assert_eq!(deserialized.results.len(), response.results.len());
         assert_eq!(deserialized.model, response.model);
@@ -2539,13 +2723,13 @@ mod tests {
                 ("confidence".to_string(), Value::String("high".to_string())),
                 (
                     "processing_time".to_string(),
-                    Value::Number(serde_json::Number::from(150)),
+                    Value::Number(Number::from(150)),
                 ),
             ])),
         };
 
-        let serialized = serde_json::to_string(&result).unwrap();
-        let deserialized: RerankResult = serde_json::from_str(&serialized).unwrap();
+        let serialized = to_string(&result).unwrap();
+        let deserialized: RerankResult = from_str(&serialized).unwrap();
 
         assert_eq!(deserialized.score, result.score);
         assert_eq!(deserialized.document, result.document);
@@ -2562,8 +2746,8 @@ mod tests {
             meta_info: None,
         };
 
-        let serialized = serde_json::to_string(&result).unwrap();
-        let deserialized: RerankResult = serde_json::from_str(&serialized).unwrap();
+        let serialized = to_string(&result).unwrap();
+        let deserialized: RerankResult = from_str(&serialized).unwrap();
 
         assert_eq!(deserialized.score, result.score);
         assert_eq!(deserialized.document, result.document);
@@ -2582,8 +2766,8 @@ mod tests {
             documents: vec!["doc1".to_string(), "doc2".to_string()],
         };
 
-        let serialized = serde_json::to_string(&v1_input).unwrap();
-        let deserialized: V1RerankReqInput = serde_json::from_str(&serialized).unwrap();
+        let serialized = to_string(&v1_input).unwrap();
+        let deserialized: V1RerankReqInput = from_str(&serialized).unwrap();
 
         assert_eq!(deserialized.query, v1_input.query);
         assert_eq!(deserialized.documents, v1_input.documents);
@@ -2724,8 +2908,8 @@ mod tests {
             prompt_tokens_details: None,
         });
 
-        let serialized = serde_json::to_string(&response).unwrap();
-        let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap();
+        let serialized = to_string(&response).unwrap();
+        let deserialized: RerankResponse = from_str(&serialized).unwrap();
 
         assert!(deserialized.usage.is_some());
         let usage = deserialized.usage.unwrap();
@@ -2805,8 +2989,8 @@ mod tests {
         assert_eq!(response.model, "rerank-model");
 
         // Serialize and deserialize
-        let serialized = serde_json::to_string(&response).unwrap();
-        let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap();
+        let serialized = to_string(&response).unwrap();
+        let deserialized: RerankResponse = from_str(&serialized).unwrap();
         assert_eq!(deserialized.results.len(), 2);
         assert_eq!(deserialized.model, response.model);
     }
@@ -2819,15 +3003,15 @@ mod tests {
     fn test_embedding_request_serialization_string_input() {
         let req = EmbeddingRequest {
             model: "test-emb".to_string(),
-            input: serde_json::Value::String("hello".to_string()),
+            input: Value::String("hello".to_string()),
             encoding_format: Some("float".to_string()),
             user: Some("user-1".to_string()),
             dimensions: Some(128),
             rid: Some("rid-123".to_string()),
         };
 
-        let serialized = serde_json::to_string(&req).unwrap();
-        let deserialized: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
+        let serialized = to_string(&req).unwrap();
+        let deserialized: EmbeddingRequest = from_str(&serialized).unwrap();
 
         assert_eq!(deserialized.model, req.model);
         assert_eq!(deserialized.input, req.input);
@@ -2841,15 +3025,15 @@ mod tests {
     fn test_embedding_request_serialization_array_input() {
         let req = EmbeddingRequest {
             model: "test-emb".to_string(),
-            input: serde_json::json!(["a", "b", "c"]),
+            input: json!(["a", "b", "c"]),
             encoding_format: None,
             user: None,
             dimensions: None,
             rid: None,
         };
 
-        let serialized = serde_json::to_string(&req).unwrap();
-        let de: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
+        let serialized = to_string(&req).unwrap();
+        let de: EmbeddingRequest = from_str(&serialized).unwrap();
         assert_eq!(de.model, req.model);
         assert_eq!(de.input, req.input);
     }
@@ -2858,7 +3042,7 @@ mod tests {
     fn test_embedding_generation_request_trait_string() {
         let req = EmbeddingRequest {
             model: "emb-model".to_string(),
-            input: serde_json::Value::String("hello".to_string()),
+            input: Value::String("hello".to_string()),
             encoding_format: None,
             user: None,
             dimensions: None,
@@ -2873,7 +3057,7 @@ mod tests {
     fn test_embedding_generation_request_trait_array() {
         let req = EmbeddingRequest {
             model: "emb-model".to_string(),
-            input: serde_json::json!(["hello", "world"]),
+            input: json!(["hello", "world"]),
             encoding_format: None,
             user: None,
             dimensions: None,
@@ -2886,7 +3070,7 @@ mod tests {
     fn test_embedding_generation_request_trait_non_text() {
         let req = EmbeddingRequest {
             model: "emb-model".to_string(),
-            input: serde_json::json!({"tokens": [1, 2, 3]}),
+            input: json!({"tokens": [1, 2, 3]}),
             encoding_format: None,
             user: None,
             dimensions: None,
@@ -2899,7 +3083,7 @@ mod tests {
     fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
         let req = EmbeddingRequest {
             model: "emb-model".to_string(),
-            input: serde_json::json!(["a", ["b", "c"], 123, {"k": "v"}]),
+            input: json!(["a", ["b", "c"], 123, {"k": "v"}]),
             encoding_format: None,
             user: None,
             dimensions: None,

@@ -166,8 +166,12 @@ impl RouterFactory {
             .cloned()
             .ok_or_else(|| "OpenAI mode requires at least one worker URL".to_string())?;
 
-        let router =
-            OpenAIRouter::new(base_url, Some(ctx.router_config.circuit_breaker.clone())).await?;
+        let router = OpenAIRouter::new(
+            base_url,
+            Some(ctx.router_config.circuit_breaker.clone()),
+            ctx.response_storage.clone(),
+        )
+        .await?;
 
         Ok(Box::new(router))
     }

@@ -308,7 +308,12 @@ impl RouterTrait for GrpcPDRouter {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
 
-    async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
+    async fn get_response(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _response_id: &str,
+        _params: &crate::protocols::spec::ResponsesGetParams,
+    ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
 

@@ -237,7 +237,12 @@ impl RouterTrait for GrpcRouter {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
 
-    async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
+    async fn get_response(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _response_id: &str,
+        _params: &crate::protocols::spec::ResponsesGetParams,
+    ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
    }
 
@@ -51,3 +51,45 @@ fn should_forward_header(name: &str) -> bool {
             "host" // Should not forward the backend's host header
     )
 }
+
+/// Apply headers to a reqwest request builder, filtering out headers that shouldn't be forwarded
+/// or that will be set automatically by reqwest
+pub fn apply_request_headers(
+    headers: &HeaderMap,
+    mut request_builder: reqwest::RequestBuilder,
+    skip_content_headers: bool,
+) -> reqwest::RequestBuilder {
+    // Always forward Authorization header first if present
+    if let Some(auth) = headers
+        .get("authorization")
+        .or_else(|| headers.get("Authorization"))
+    {
+        request_builder = request_builder.header("Authorization", auth.clone());
+    }
+
+    // Forward other headers, filtering out problematic ones
+    for (key, value) in headers.iter() {
+        let key_str = key.as_str().to_lowercase();
+
+        // Skip headers that:
+        // - Are set automatically by reqwest (content-type, content-length for POST/PUT)
+        // - We already handled (authorization)
+        // - Are hop-by-hop headers (connection, transfer-encoding)
+        // - Should not be forwarded (host)
+        let should_skip = key_str == "authorization" || // Already handled above
+            key_str == "host" ||
+            key_str == "connection" ||
+            key_str == "transfer-encoding" ||
+            key_str == "keep-alive" ||
+            key_str == "te" ||
+            key_str == "trailers" ||
+            key_str == "upgrade" ||
+            (skip_content_headers && (key_str == "content-type" || key_str == "content-length"));
+
+        if !should_skip {
+            request_builder = request_builder.header(key.clone(), value.clone());
+        }
+    }
+
+    request_builder
+}
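A typical call site mirrors how the OpenAI router below uses this helper: build the outgoing reqwest request first, then layer the client's headers on top, passing `skip_content_headers = true` because `.json(...)` already set `content-type`/`content-length`. The URL and payload here are illustrative:

    use axum::http::HeaderMap;

    fn forward(client: &reqwest::Client, headers: &HeaderMap) -> reqwest::RequestBuilder {
        let builder = client
            .post("https://upstream.example/v1/responses")
            .json(&serde_json::json!({"input": "hello"}));
        // Authorization is forwarded, hop-by-hop and host headers are dropped,
        // and content-* headers are skipped since .json() set them.
        apply_request_headers(headers, builder, true)
    }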
@@ -1,10 +1,15 @@
-//! OpenAI router implementation (reqwest-based)
+//! OpenAI router implementation
 
 use crate::config::CircuitBreakerConfig;
 use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
+use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse};
 use crate::protocols::spec::{
-    ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest,
+    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
+    ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
+    ResponseStatus, ResponseTextFormat, ResponsesGetParams, ResponsesRequest, ResponsesResponse,
+    TextFormatType,
 };
+use crate::routers::header_utils::{apply_request_headers, preserve_response_headers};
 use async_trait::async_trait;
 use axum::{
     body::Body,
@@ -13,13 +18,17 @@ use axum::{
     response::{IntoResponse, Response},
 };
 use futures_util::StreamExt;
+use serde_json::{json, to_value, Value};
 use std::{
     any::Any,
+    collections::HashMap,
     sync::atomic::{AtomicBool, Ordering},
 };
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::UnboundedReceiverStream;
+use tracing::{error, info, warn};
 
 /// Router for OpenAI backend
-#[derive(Debug)]
 pub struct OpenAIRouter {
     /// HTTP client for upstream OpenAI-compatible API
     client: reqwest::Client,
@@ -29,6 +38,17 @@ pub struct OpenAIRouter {
     circuit_breaker: CircuitBreaker,
     /// Health status
     healthy: AtomicBool,
+    /// Response storage for managing conversation history
+    response_storage: SharedResponseStorage,
+}
+
+impl std::fmt::Debug for OpenAIRouter {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("OpenAIRouter")
+            .field("base_url", &self.base_url)
+            .field("healthy", &self.healthy)
+            .finish()
+    }
 }
 
 impl OpenAIRouter {
@@ -36,6 +56,7 @@ impl OpenAIRouter {
     pub async fn new(
         base_url: String,
         circuit_breaker_config: Option<CircuitBreakerConfig>,
+        response_storage: SharedResponseStorage,
     ) -> Result<Self, String> {
         let client = reqwest::Client::builder()
             .timeout(std::time::Duration::from_secs(300))
@@ -61,8 +82,246 @@ impl OpenAIRouter {
             base_url,
             circuit_breaker,
             healthy: AtomicBool::new(true),
+            response_storage,
         })
     }
+
+    async fn handle_non_streaming_response(
+        &self,
+        url: String,
+        headers: Option<&HeaderMap>,
+        payload: Value,
+        original_body: &ResponsesRequest,
+        original_previous_response_id: Option<String>,
+    ) -> Response {
+        let request_builder = self.client.post(&url).json(&payload);
+
+        // Apply headers with filtering
+        let request_builder = if let Some(headers) = headers {
+            apply_request_headers(headers, request_builder, true)
+        } else {
+            request_builder
+        };
+
+        match request_builder.send().await {
+            Ok(response) => {
+                let status = response.status();
+
+                if !status.is_success() {
+                    let error_text = response
+                        .text()
+                        .await
+                        .unwrap_or_else(|e| format!("Failed to get error body: {}", e));
+                    return (status, error_text).into_response();
+                }
+
+                // Parse the response
+                match response.json::<Value>().await {
+                    Ok(mut openai_response_json) => {
+                        if let Some(prev_id) = original_previous_response_id {
+                            if let Some(obj) = openai_response_json.as_object_mut() {
+                                let should_insert = obj
+                                    .get("previous_response_id")
+                                    .map(|v| v.is_null())
+                                    .unwrap_or(true);
+                                if should_insert {
+                                    obj.insert(
+                                        "previous_response_id".to_string(),
+                                        Value::String(prev_id),
+                                    );
+                                }
+                            }
+                        }
+
+                        if let Some(obj) = openai_response_json.as_object_mut() {
+                            if !obj.contains_key("instructions") {
+                                if let Some(instructions) = &original_body.instructions {
+                                    obj.insert(
+                                        "instructions".to_string(),
+                                        Value::String(instructions.clone()),
+                                    );
+                                }
+                            }
+
+                            if !obj.contains_key("metadata") {
+                                if let Some(metadata) = &original_body.metadata {
+                                    let metadata_map: serde_json::Map<String, Value> = metadata
+                                        .iter()
+                                        .map(|(k, v)| (k.clone(), v.clone()))
+                                        .collect();
+                                    obj.insert("metadata".to_string(), Value::Object(metadata_map));
+                                }
+                            }
+
+                            // Reflect the client's requested store preference in the response body
+                            obj.insert("store".to_string(), Value::Bool(original_body.store));
+                        }
+
+                        if original_body.store {
+                            if let Err(e) = self
+                                .store_response_internal(&openai_response_json, original_body)
+                                .await
+                            {
+                                warn!("Failed to store response: {}", e);
+                            }
+                        }
+
+                        match serde_json::to_string(&openai_response_json) {
+                            Ok(json_str) => (
+                                StatusCode::OK,
+                                [("content-type", "application/json")],
+                                json_str,
+                            )
+                                .into_response(),
+                            Err(e) => {
+                                error!("Failed to serialize response: {}", e);
+                                (
+                                    StatusCode::INTERNAL_SERVER_ERROR,
+                                    json!({"error": {"message": "Failed to serialize response", "type": "internal_error"}}).to_string(),
+                                )
+                                    .into_response()
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        error!("Failed to parse OpenAI response: {}", e);
+                        (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            format!("Failed to parse response: {}", e),
+                        )
+                            .into_response()
+                    }
+                }
+            }
+            Err(e) => (
+                StatusCode::BAD_GATEWAY,
+                format!("Failed to forward request to OpenAI: {}", e),
+            )
+                .into_response(),
+        }
+    }
+
+    async fn handle_streaming_response(
+        &self,
+        _url: String,
+        _headers: Option<&HeaderMap>,
+        _payload: Value,
+        _original_body: &ResponsesRequest,
+        _original_previous_response_id: Option<String>,
+    ) -> Response {
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "Streaming responses not yet implemented",
+        )
+            .into_response()
+    }
+
+    async fn store_response_internal(
+        &self,
+        response_json: &Value,
+        original_body: &ResponsesRequest,
+    ) -> Result<(), String> {
+        if !original_body.store {
+            return Ok(());
+        }
+
+        match Self::store_response_impl(&self.response_storage, response_json, original_body).await
+        {
+            Ok(response_id) => {
+                info!(response_id = %response_id.0, "Stored response locally");
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
+    }
+
+    async fn store_response_impl(
+        response_storage: &SharedResponseStorage,
+        response_json: &Value,
+        original_body: &ResponsesRequest,
+    ) -> Result<ResponseId, String> {
+        let input_text = match &original_body.input {
+            ResponseInput::Text(text) => text.clone(),
+            ResponseInput::Items(_) => "complex input".to_string(),
+        };
+
+        let output_text = Self::extract_primary_output_text(response_json).unwrap_or_default();
+
+        let mut stored_response = StoredResponse::new(input_text, output_text, None);
+
+        stored_response.instructions = response_json
+            .get("instructions")
+            .and_then(|v| v.as_str())
+            .map(|s| s.to_string())
+            .or_else(|| original_body.instructions.clone());
+
+        stored_response.model = response_json
+            .get("model")
+            .and_then(|v| v.as_str())
+            .map(|s| s.to_string())
+            .or_else(|| original_body.model.clone());
+
+        stored_response.user = response_json
+            .get("user")
+            .and_then(|v| v.as_str())
+            .map(|s| s.to_string())
+            .or_else(|| original_body.user.clone());
+
+        stored_response.metadata = response_json
+            .get("metadata")
+            .and_then(|v| v.as_object())
+            .map(|m| {
+                m.iter()
+                    .map(|(k, v)| (k.clone(), v.clone()))
+                    .collect::<HashMap<_, _>>()
+            })
+            .unwrap_or_else(|| original_body.metadata.clone().unwrap_or_default());
+
+        stored_response.previous_response_id = response_json
+            .get("previous_response_id")
+            .and_then(|v| v.as_str())
+            .map(|s| ResponseId::from_string(s.to_string()))
+            .or_else(|| {
+                original_body
+                    .previous_response_id
+                    .as_ref()
+                    .map(|id| ResponseId::from_string(id.clone()))
+            });
+
+        if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) {
+            stored_response.id = ResponseId::from_string(id_str.to_string());
+        }
+
+        stored_response.raw_response = response_json.clone();
+
+        response_storage
+            .store_response(stored_response)
+            .await
+            .map_err(|e| format!("Failed to store response: {}", e))
+    }
+
+    fn extract_primary_output_text(response_json: &Value) -> Option<String> {
+        if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) {
+            for item in items {
+                if let Some(content) = item.get("content").and_then(|v| v.as_array()) {
+                    for part in content {
+                        if part
+                            .get("type")
+                            .and_then(|v| v.as_str())
+                            .map(|t| t == "output_text")
+                            .unwrap_or(false)
+                        {
+                            if let Some(text) = part.get("text").and_then(|v| v.as_str()) {
+                                return Some(text.to_string());
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        None
+    }
 }
 
 #[async_trait]
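For reference, `extract_primary_output_text` above walks the Responses API `output` array and returns the first `output_text` part, which is what gets persisted as the stored output. A test-style sketch of the shape it expects (payload abbreviated):

    #[test]
    fn picks_first_output_text_part() {
        let payload = serde_json::json!({
            "output": [{
                "type": "message",
                "content": [
                    {"type": "reasoning", "text": "thinking..."},
                    {"type": "output_text", "text": "Hello!", "annotations": []}
                ]
            }]
        });
        assert_eq!(
            OpenAIRouter::extract_primary_output_text(&payload).as_deref(),
            Some("Hello!")
        );
    }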
@@ -108,7 +367,7 @@ impl super::super::RouterTrait for OpenAIRouter {
     }
 
     async fn get_server_info(&self, _req: Request<Body>) -> Response {
-        let info = serde_json::json!({
+        let info = json!({
             "router_type": "openai",
             "workers": 1,
             "base_url": &self.base_url
@@ -192,7 +451,7 @@ impl super::super::RouterTrait for OpenAIRouter {
     }
 
     // Serialize request body, removing SGLang-only fields
-    let mut payload = match serde_json::to_value(body) {
+    let mut payload = match to_value(body) {
         Ok(v) => v,
         Err(e) => {
             return (
@@ -282,7 +541,7 @@ impl super::super::RouterTrait for OpenAIRouter {
         } else {
             // Stream SSE bytes to client
             let stream = resp.bytes_stream();
-            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
+            let (tx, rx) = mpsc::unbounded_channel();
             tokio::spawn(async move {
                 let mut s = stream;
                 while let Some(chunk) = s.next().await {
@@ -299,9 +558,7 @@ impl super::super::RouterTrait for OpenAIRouter {
                     }
                 }
             });
-            let mut response = Response::new(Body::from_stream(
-                tokio_stream::wrappers::UnboundedReceiverStream::new(rx),
-            ));
+            let mut response = Response::new(Body::from_stream(UnboundedReceiverStream::new(rx)));
             *response.status_mut() = status;
             response
                 .headers_mut()
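For context on the streaming branch above: the upstream body is pumped through an unbounded channel so the spawned task owns the reqwest stream while axum owns the receiver. A minimal sketch of the same relay pattern, assuming axum 0.7 with reqwest's `stream` feature, tokio-stream, and futures-util (standalone names, not the router's):

use axum::{body::Body, response::Response};
use futures_util::StreamExt;
use tokio_stream::wrappers::UnboundedReceiverStream;

// Relay an upstream SSE/byte stream into an axum Body without buffering it all.
async fn relay(upstream: reqwest::Response) -> Response {
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
    let mut chunks = upstream.bytes_stream();
    tokio::spawn(async move {
        while let Some(chunk) = chunks.next().await {
            // Stop pumping once the client side has hung up.
            if tx.send(chunk).is_err() {
                break;
            }
        }
    });
    Response::new(Body::from_stream(UnboundedReceiverStream::new(rx)))
}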
@@ -326,36 +583,294 @@ impl super::super::RouterTrait for OpenAIRouter {
 
     async fn route_responses(
         &self,
-        _headers: Option<&HeaderMap>,
-        _body: &crate::protocols::spec::ResponsesRequest,
-        _model_id: Option<&str>,
+        headers: Option<&HeaderMap>,
+        body: &ResponsesRequest,
+        model_id: Option<&str>,
     ) -> Response {
+        let url = format!("{}/v1/responses", self.base_url);
+
+        info!(
+            requested_store = body.store,
+            is_streaming = body.stream,
+            "openai_responses_request"
+        );
+
+        if body.stream {
+            return (
+                StatusCode::NOT_IMPLEMENTED,
+                "Streaming responses not yet implemented",
+            )
+                .into_response();
+        }
+
+        // Clone the body and override model if needed
+        let mut request_body = body.clone();
+        if let Some(model) = model_id {
+            request_body.model = Some(model.to_string());
+        }
+
+        // Store the original previous_response_id for the response
+        let original_previous_response_id = request_body.previous_response_id.clone();
+
+        // Handle previous_response_id by loading prior context
+        let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
+        if let Some(prev_id_str) = request_body.previous_response_id.clone() {
+            let prev_id = ResponseId::from_string(prev_id_str.clone());
+            match self
+                .response_storage
+                .get_response_chain(&prev_id, None)
+                .await
+            {
+                Ok(chain) => {
+                    if !chain.responses.is_empty() {
+                        let mut items = Vec::new();
+                        for stored in chain.responses.iter() {
+                            let trimmed_id = stored.id.0.trim_start_matches("resp_");
+                            if !stored.input.is_empty() {
+                                items.push(ResponseInputOutputItem::Message {
+                                    id: format!("msg_u_{}", trimmed_id),
+                                    role: "user".to_string(),
+                                    status: Some("completed".to_string()),
+                                    content: vec![ResponseContentPart::InputText {
+                                        text: stored.input.clone(),
+                                    }],
+                                });
+                            }
+                            if !stored.output.is_empty() {
+                                items.push(ResponseInputOutputItem::Message {
+                                    id: format!("msg_a_{}", trimmed_id),
+                                    role: "assistant".to_string(),
+                                    status: Some("completed".to_string()),
+                                    content: vec![ResponseContentPart::OutputText {
+                                        text: stored.output.clone(),
+                                        annotations: vec![],
+                                        logprobs: None,
+                                    }],
+                                });
+                            }
+                        }
+                        conversation_items = Some(items);
+                    } else {
+                        info!(previous_response_id = %prev_id_str, "previous chain empty");
+                    }
+                }
+                Err(err) => {
+                    warn!(previous_response_id = %prev_id_str, %err, "failed to fetch previous response chain");
+                }
+            }
+            // Clear previous_response_id from request since we're converting to conversation
+            request_body.previous_response_id = None;
+        }
+
+        if let Some(mut items) = conversation_items {
+            match &request_body.input {
+                ResponseInput::Text(text) => {
+                    items.push(ResponseInputOutputItem::Message {
+                        id: format!("msg_u_current_{}", items.len()),
+                        role: "user".to_string(),
+                        status: Some("completed".to_string()),
+                        content: vec![ResponseContentPart::InputText { text: text.clone() }],
+                    });
+                }
+                ResponseInput::Items(existing) => {
+                    items.extend(existing.clone());
+                }
+            }
+            request_body.input = ResponseInput::Items(items);
+        }
+
+        // Always set store=false for OpenAI (we store internally)
+        request_body.store = false;
+
+        // Convert to JSON payload and strip SGLang-specific fields before forwarding
+        let mut payload = match to_value(&request_body) {
+            Ok(value) => value,
+            Err(err) => {
+                return (
+                    StatusCode::BAD_REQUEST,
+                    format!("Failed to serialize responses request: {}", err),
+                )
+                    .into_response();
+            }
+        };
+        if let Some(obj) = payload.as_object_mut() {
+            for key in [
+                "request_id",
+                "priority",
+                "frequency_penalty",
+                "presence_penalty",
+                "stop",
+                "top_k",
+                "min_p",
+                "repetition_penalty",
+            ] {
+                obj.remove(key);
+            }
+        }
+
+        // Check if streaming is requested
+        if body.stream {
+            // Handle streaming response
+            self.handle_streaming_response(
+                url,
+                headers,
+                payload,
+                body,
+                original_previous_response_id,
+            )
+            .await
+        } else {
+            // Handle non-streaming response
+            self.handle_non_streaming_response(
+                url,
+                headers,
+                payload,
+                body,
+                original_previous_response_id,
+            )
+            .await
+        }
+    }
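To make the chain replay concrete: with one stored turn ("Say hi" → "Hello!") and a new text input "Thanks", the request forwarded upstream carries the flattened conversation below and no `previous_response_id`. The `input_text`/`output_text` wire tags are an assumption about how `ResponseContentPart` serializes, shown for illustration only:

use serde_json::json;

fn main() {
    // Hypothetical flattened `input` for previous_response_id = "resp_1";
    // content-part tags assume the serde representation of ResponseContentPart.
    let forwarded_input = json!([
        { "type": "message", "id": "msg_u_1", "role": "user", "status": "completed",
          "content": [{ "type": "input_text", "text": "Say hi" }] },
        { "type": "message", "id": "msg_a_1", "role": "assistant", "status": "completed",
          "content": [{ "type": "output_text", "text": "Hello!", "annotations": [] }] },
        { "type": "message", "id": "msg_u_current_2", "role": "user", "status": "completed",
          "content": [{ "type": "input_text", "text": "Thanks" }] }
    ]);
    println!("{}", serde_json::to_string_pretty(&forwarded_input).unwrap());
}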
+    async fn get_response(
+        &self,
+        _headers: Option<&HeaderMap>,
+        response_id: &str,
+        params: &ResponsesGetParams,
+    ) -> Response {
+        let stored_id = ResponseId::from_string(response_id.to_string());
+        if let Ok(Some(stored_response)) = self.response_storage.get_response(&stored_id).await {
+            let stream_requested = params.stream.unwrap_or(false);
+            let raw_value = stored_response.raw_response.clone();
+
+            if !raw_value.is_null() {
+                if stream_requested {
+                    return (
+                        StatusCode::NOT_IMPLEMENTED,
+                        "Streaming retrieval not yet implemented",
+                    )
+                        .into_response();
+                }
+
+                return (
+                    StatusCode::OK,
+                    [("content-type", "application/json")],
+                    raw_value.to_string(),
+                )
+                    .into_response();
+            }
+
+            let openai_response = ResponsesResponse {
+                id: stored_response.id.0.clone(),
+                object: "response".to_string(),
+                created_at: stored_response.created_at.timestamp(),
+                status: ResponseStatus::Completed,
+                error: None,
+                incomplete_details: None,
+                instructions: stored_response.instructions.clone(),
+                max_output_tokens: None,
+                model: stored_response
+                    .model
+                    .unwrap_or_else(|| "gpt-4o".to_string()),
+                output: vec![ResponseOutputItem::Message {
+                    id: format!("msg_{}", stored_response.id.0),
+                    role: "assistant".to_string(),
+                    status: "completed".to_string(),
+                    content: vec![ResponseContentPart::OutputText {
+                        text: stored_response.output,
+                        annotations: vec![],
+                        logprobs: None,
+                    }],
+                }],
+                parallel_tool_calls: true,
+                previous_response_id: stored_response.previous_response_id.map(|id| id.0),
+                reasoning: None,
+                store: true,
+                temperature: Some(1.0),
+                text: Some(ResponseTextFormat {
+                    format: TextFormatType {
+                        format_type: "text".to_string(),
+                    },
+                }),
+                tool_choice: "auto".to_string(),
+                tools: vec![],
+                top_p: Some(1.0),
+                truncation: Some("disabled".to_string()),
+                usage: None,
+                user: stored_response.user.clone(),
+                metadata: stored_response.metadata.clone(),
+            };
+
+            if stream_requested {
+                return (
+                    StatusCode::NOT_IMPLEMENTED,
+                    "Streaming retrieval not yet implemented",
+                )
+                    .into_response();
+            }
+
+            return (
+                StatusCode::OK,
+                [("content-type", "application/json")],
+                serde_json::to_string(&openai_response).unwrap_or_else(|e| {
+                    format!("{{\"error\": \"Failed to serialize response: {}\"}}", e)
+                }),
+            )
+                .into_response();
+        }
+
         (
-            StatusCode::NOT_IMPLEMENTED,
-            "Responses endpoint not implemented for OpenAI router",
+            StatusCode::NOT_FOUND,
+            format!(
+                "Response with id '{}' not found in local storage",
+                response_id
+            ),
         )
             .into_response()
     }
 
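Worth noting about the retrieval path just added: when a `raw_response` was captured at store time it is served back verbatim, and the hand-assembled `ResponsesResponse` (with its `gpt-4o` and `temperature` placeholders) is only a fallback for records persisted before `raw_response` existed. A usage sketch, assuming a `router` built over the same `Arc<MemoryResponseStorage>` that `route_responses` writes into, with a hypothetical id:

// Test-style fragment: fetch a previously stored response through the router.
let params = ResponsesGetParams::default(); // stream unset => plain JSON body
let resp = router.get_response(None, "resp_123", &params).await;
assert_eq!(resp.status(), StatusCode::OK);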
-    async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
-        (
-            StatusCode::NOT_IMPLEMENTED,
-            "Responses retrieve endpoint not implemented for OpenAI router",
-        )
-            .into_response()
-    }
-
-    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
-        (
-            StatusCode::NOT_IMPLEMENTED,
-            "Responses cancel endpoint not implemented for OpenAI router",
-        )
-            .into_response()
+    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
+        // Forward to OpenAI's cancel endpoint
+        let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id);
+
+        let request_builder = self.client.post(&url);
+
+        // Apply headers with filtering (skip content headers for POST without body)
+        let request_builder = if let Some(headers) = headers {
+            apply_request_headers(headers, request_builder, true)
+        } else {
+            request_builder
+        };
+
+        match request_builder.send().await {
+            Ok(response) => {
+                let status = response.status();
+                let headers = response.headers().clone();
+
+                match response.text().await {
+                    Ok(body_text) => {
+                        let mut response = (status, body_text).into_response();
+                        *response.headers_mut() = preserve_response_headers(&headers);
+                        response
+                    }
+                    Err(e) => (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        format!("Failed to read response body: {}", e),
+                    )
+                        .into_response(),
+                }
+            }
+            Err(e) => (
+                StatusCode::BAD_GATEWAY,
+                format!("Failed to cancel response on OpenAI: {}", e),
+            )
+                .into_response(),
+        }
     }
 
     async fn flush_cache(&self) -> Response {
         (
-            StatusCode::NOT_IMPLEMENTED,
+            StatusCode::FORBIDDEN,
             "flush_cache not supported for OpenAI router",
         )
             .into_response()
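Cancellation, by contrast, holds no local state: it is a plain proxy to the upstream cancel route. The equivalent standalone call, with a hypothetical base URL and id (reqwest only):

// POST {base}/v1/responses/{id}/cancel with an empty body; returns the upstream status.
async fn cancel_upstream(base: &str, id: &str) -> Result<reqwest::StatusCode, reqwest::Error> {
    let url = format!("{}/v1/responses/{}/cancel", base, id);
    Ok(reqwest::Client::new().post(&url).send().await?.status())
}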
@@ -363,7 +878,7 @@ impl super::super::RouterTrait for OpenAIRouter {
 
     async fn get_worker_loads(&self) -> Response {
         (
-            StatusCode::NOT_IMPLEMENTED,
+            StatusCode::FORBIDDEN,
             "get_worker_loads not supported for OpenAI router",
         )
             .into_response()
@@ -384,12 +899,12 @@ impl super::super::RouterTrait for OpenAIRouter {
     async fn route_embeddings(
         &self,
         _headers: Option<&HeaderMap>,
-        _body: &crate::protocols::spec::EmbeddingRequest,
+        _body: &EmbeddingRequest,
         _model_id: Option<&str>,
     ) -> Response {
         (
-            StatusCode::NOT_IMPLEMENTED,
-            "Embeddings endpoint not implemented for OpenAI backend",
+            StatusCode::FORBIDDEN,
+            "Embeddings endpoint not supported for OpenAI backend",
         )
             .into_response()
     }
@@ -401,8 +916,8 @@ impl super::super::RouterTrait for OpenAIRouter {
         _model_id: Option<&str>,
     ) -> Response {
         (
-            StatusCode::NOT_IMPLEMENTED,
-            "Rerank endpoint not implemented for OpenAI backend",
+            StatusCode::FORBIDDEN,
+            "Rerank endpoint not supported for OpenAI backend",
         )
             .into_response()
     }
@@ -8,7 +8,7 @@ use crate::metrics::RouterMetrics;
 use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
 use crate::protocols::spec::{
     ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, RerankRequest,
-    ResponsesRequest, StringOrArray, UserMessageContent,
+    ResponsesGetParams, ResponsesRequest, StringOrArray, UserMessageContent,
 };
 use crate::routers::header_utils;
 use crate::routers::RouterTrait;
@@ -1424,7 +1424,12 @@ impl RouterTrait for PDRouter {
         .into_response()
     }
 
-    async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
+    async fn get_response(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _response_id: &str,
+        _params: &ResponsesGetParams,
+    ) -> Response {
         (
             StatusCode::NOT_IMPLEMENTED,
             "Responses retrieve endpoint not implemented for PD router",
@@ -6,7 +6,7 @@ use crate::metrics::RouterMetrics;
 use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
 use crate::protocols::spec::{
     ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
-    RerankRequest, RerankResponse, RerankResult, ResponsesRequest,
+    RerankRequest, RerankResponse, RerankResult, ResponsesGetParams, ResponsesRequest,
 };
 use crate::routers::header_utils;
 use crate::routers::RouterTrait;
@@ -903,7 +903,12 @@ impl RouterTrait for Router {
         .await
     }
 
-    async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
+    async fn get_response(
+        &self,
+        headers: Option<&HeaderMap>,
+        response_id: &str,
+        _params: &ResponsesGetParams,
+    ) -> Response {
         let endpoint = format!("v1/responses/{}", response_id);
         self.route_get_request(headers, &endpoint).await
     }
@@ -11,7 +11,7 @@ use std::fmt::Debug;
 
 use crate::protocols::spec::{
     ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
-    ResponsesRequest,
+    ResponsesGetParams, ResponsesRequest,
 };
 
 pub mod factory;
@@ -82,7 +82,12 @@ pub trait RouterTrait: Send + Sync + Debug {
     ) -> Response;
 
     /// Retrieve a stored/background response by id
-    async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response;
+    async fn get_response(
+        &self,
+        headers: Option<&HeaderMap>,
+        response_id: &str,
+        params: &ResponsesGetParams,
+    ) -> Response;
 
     /// Cancel a background response by id
     async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response;
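Every `RouterTrait` implementor has to absorb the extra `params` argument. A minimal conforming stub, for a hypothetical router type and ignoring everything but the stream flag:

// Fragment inside a hypothetical `impl RouterTrait for SomeRouter` block:
async fn get_response(
    &self,
    _headers: Option<&HeaderMap>,
    response_id: &str,
    params: &ResponsesGetParams,
) -> Response {
    if params.stream.unwrap_or(false) {
        return (StatusCode::NOT_IMPLEMENTED, "streaming GET unsupported").into_response();
    }
    (
        StatusCode::NOT_FOUND,
        format!("unknown response '{}'", response_id),
    )
        .into_response()
}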
@@ -8,7 +8,7 @@ use crate::config::{ConnectionMode, RoutingMode};
 use crate::core::{WorkerRegistry, WorkerType};
 use crate::protocols::spec::{
     ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
-    ResponsesRequest,
+    ResponsesGetParams, ResponsesRequest,
 };
 use crate::routers::RouterTrait;
 use crate::server::{AppContext, ServerConfig};
@@ -403,38 +403,19 @@ impl RouterTrait for RouterManager {
 
     async fn route_responses(
         &self,
-        _headers: Option<&HeaderMap>,
-        _body: &ResponsesRequest,
-        _model_id: Option<&str>,
+        headers: Option<&HeaderMap>,
+        body: &ResponsesRequest,
+        model_id: Option<&str>,
     ) -> Response {
-        (
-            StatusCode::NOT_IMPLEMENTED,
-            "responses api not yet implemented in inference gateway mode",
-        )
-            .into_response()
-    }
-
-    async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
-        let router = self.select_router_for_request(headers, None);
+        let selected_model = body.model.as_deref().or(model_id);
+        let router = self.select_router_for_request(headers, selected_model);
         if let Some(router) = router {
-            router.get_response(headers, response_id).await
+            router.route_responses(headers, body, selected_model).await
         } else {
             (
                 StatusCode::NOT_FOUND,
-                format!("No router available to get response '{}'", response_id),
-            )
-                .into_response()
-        }
-    }
-
-    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
-        let router = self.select_router_for_request(headers, None);
-        if let Some(router) = router {
-            router.cancel_response(headers, response_id).await
-        } else {
-            (
-                StatusCode::NOT_FOUND,
-                format!("No router available to cancel response '{}'", response_id),
+                "No router available to handle responses request",
             )
                 .into_response()
         }
     }
@@ -460,6 +441,37 @@ impl RouterTrait for RouterManager {
         .into_response()
     }
 
+    async fn get_response(
+        &self,
+        headers: Option<&HeaderMap>,
+        response_id: &str,
+        params: &ResponsesGetParams,
+    ) -> Response {
+        let router = self.select_router_for_request(headers, None);
+        if let Some(router) = router {
+            router.get_response(headers, response_id, params).await
+        } else {
+            (
+                StatusCode::NOT_FOUND,
+                format!("No router available to get response '{}'", response_id),
+            )
+                .into_response()
+        }
+    }
+
+    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
+        let router = self.select_router_for_request(headers, None);
+        if let Some(router) = router {
+            router.cancel_response(headers, response_id).await
+        } else {
+            (
+                StatusCode::NOT_FOUND,
+                format!("No router available to cancel response '{}'", response_id),
+            )
+                .into_response()
+        }
+    }
+
     async fn route_embeddings(
         &self,
         headers: Option<&HeaderMap>,
@@ -9,7 +9,7 @@ use crate::{
     protocols::{
         spec::{
            ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest,
-            RerankRequest, ResponsesRequest, V1RerankReqInput,
+            RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput,
         },
         worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
     },
@@ -224,10 +224,11 @@ async fn v1_responses_get(
     State(state): State<Arc<AppState>>,
     Path(response_id): Path<String>,
     headers: http::HeaderMap,
+    Query(params): Query<ResponsesGetParams>,
 ) -> Response {
     state
         .router
-        .get_response(Some(&headers), &response_id)
+        .get_response(Some(&headers), &response_id, &params)
         .await
 }
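With the `Query` extractor in place, `GET /v1/responses/{id}?stream=true` deserializes straight into `ResponsesGetParams`. A self-contained sketch of that extraction, where the local `GetParams` stands in for the real struct and only its `stream` field is an assumption:

use axum::{extract::Query, routing::get, Router};
use serde::Deserialize;

// Stand-in for ResponsesGetParams; only `stream` is assumed here.
#[derive(Debug, Default, Deserialize)]
struct GetParams {
    stream: Option<bool>,
}

async fn show(Query(params): Query<GetParams>) -> String {
    format!("stream = {:?}", params.stream)
}

#[tokio::main]
async fn main() {
    // GET /demo?stream=true  ->  "stream = Some(true)"
    let app = Router::new().route("/demo", get(show));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}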
@@ -5,17 +5,23 @@ use axum::{
     extract::Request,
     http::{Method, StatusCode},
     routing::post,
-    Router,
+    Json, Router,
 };
 use serde_json::json;
 use sglang_router_rs::{
     config::{RouterConfig, RoutingMode},
+    data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage},
     protocols::spec::{
-        ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, UserMessageContent,
+        ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
+        ResponsesGetParams, ResponsesRequest, UserMessageContent,
     },
     routers::{openai_router::OpenAIRouter, RouterTrait},
 };
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicUsize, Ordering},
+    Arc,
+};
+use tokio::net::TcpListener;
 use tower::ServiceExt;
 
 mod common;
@@ -78,7 +84,12 @@ fn create_minimal_completion_request() -> CompletionRequest {
 /// Test basic OpenAI router creation and configuration
 #[tokio::test]
 async fn test_openai_router_creation() {
-    let router = OpenAIRouter::new("https://api.openai.com".to_string(), None).await;
+    let router = OpenAIRouter::new(
+        "https://api.openai.com".to_string(),
+        None,
+        Arc::new(MemoryResponseStorage::new()),
+    )
+    .await;
 
     assert!(router.is_ok(), "Router creation should succeed");
 
@@ -90,9 +101,13 @@ async fn test_openai_router_creation() {
 /// Test health endpoints
 #[tokio::test]
 async fn test_openai_router_health() {
-    let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
-        .await
-        .unwrap();
+    let router = OpenAIRouter::new(
+        "https://api.openai.com".to_string(),
+        None,
+        Arc::new(MemoryResponseStorage::new()),
+    )
+    .await
+    .unwrap();
 
     let req = Request::builder()
         .method(Method::GET)
@@ -107,9 +122,13 @@ async fn test_openai_router_health() {
 /// Test server info endpoint
 #[tokio::test]
 async fn test_openai_router_server_info() {
-    let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
-        .await
-        .unwrap();
+    let router = OpenAIRouter::new(
+        "https://api.openai.com".to_string(),
+        None,
+        Arc::new(MemoryResponseStorage::new()),
+    )
+    .await
+    .unwrap();
 
     let req = Request::builder()
         .method(Method::GET)
@@ -132,9 +151,13 @@ async fn test_openai_router_server_info() {
 async fn test_openai_router_models() {
     // Use mock server for deterministic models response
     let mock_server = MockOpenAIServer::new().await;
-    let router = OpenAIRouter::new(mock_server.base_url(), None)
-        .await
-        .unwrap();
+    let router = OpenAIRouter::new(
+        mock_server.base_url(),
+        None,
+        Arc::new(MemoryResponseStorage::new()),
+    )
+    .await
+    .unwrap();
 
     let req = Request::builder()
         .method(Method::GET)
@@ -154,6 +177,138 @@ async fn test_openai_router_models() {
     assert!(models["data"].is_array());
 }
 
+#[tokio::test]
+async fn test_openai_router_responses_with_mock() {
+    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+    let addr = listener.local_addr().unwrap();
+    let counter = Arc::new(AtomicUsize::new(0));
+    let counter_clone = counter.clone();
+
+    let app = Router::new().route(
+        "/v1/responses",
+        post({
+            move |Json(request): Json<serde_json::Value>| {
+                let counter = counter_clone.clone();
+                async move {
+                    let idx = counter.fetch_add(1, Ordering::SeqCst) + 1;
+                    let model = request
+                        .get("model")
+                        .and_then(|v| v.as_str())
+                        .unwrap_or("gpt-4o-mini")
+                        .to_string();
+                    let id = format!("resp_mock_{idx}");
+                    let response = json!({
+                        "id": id,
+                        "object": "response",
+                        "created_at": 1_700_000_000 + idx as i64,
+                        "status": "completed",
+                        "model": model,
+                        "output": [{
+                            "type": "message",
+                            "id": format!("msg_{idx}"),
+                            "role": "assistant",
+                            "status": "completed",
+                            "content": [{
+                                "type": "output_text",
+                                "text": format!("mock_output_{idx}"),
+                                "annotations": []
+                            }]
+                        }],
+                        "metadata": {}
+                    });
+                    Json(response)
+                }
+            }
+        }),
+    );
+
+    let server = tokio::spawn(async move {
+        axum::serve(listener, app).await.unwrap();
+    });
+
+    let base_url = format!("http://{}", addr);
+    let storage = Arc::new(MemoryResponseStorage::new());
+
+    let router = OpenAIRouter::new(base_url, None, storage.clone())
+        .await
+        .unwrap();
+
+    let request1 = ResponsesRequest {
+        model: Some("gpt-4o-mini".to_string()),
+        input: ResponseInput::Text("Say hi".to_string()),
+        store: true,
+        ..Default::default()
+    };
+
+    let response1 = router.route_responses(None, &request1, None).await;
+    assert_eq!(response1.status(), StatusCode::OK);
+    let body1_bytes = axum::body::to_bytes(response1.into_body(), usize::MAX)
+        .await
+        .unwrap();
+    let body1: serde_json::Value = serde_json::from_slice(&body1_bytes).unwrap();
+    let resp1_id = body1["id"].as_str().expect("id missing").to_string();
+    assert_eq!(body1["previous_response_id"], serde_json::Value::Null);
+
+    let request2 = ResponsesRequest {
+        model: Some("gpt-4o-mini".to_string()),
+        input: ResponseInput::Text("Thanks".to_string()),
+        store: true,
+        previous_response_id: Some(resp1_id.clone()),
+        ..Default::default()
+    };
+
+    let response2 = router.route_responses(None, &request2, None).await;
+    assert_eq!(response2.status(), StatusCode::OK);
+    let body2_bytes = axum::body::to_bytes(response2.into_body(), usize::MAX)
+        .await
+        .unwrap();
+    let body2: serde_json::Value = serde_json::from_slice(&body2_bytes).unwrap();
+    let resp2_id = body2["id"].as_str().expect("second id missing");
+    assert_eq!(
+        body2["previous_response_id"].as_str(),
+        Some(resp1_id.as_str())
+    );
+
+    let stored1 = storage
+        .get_response(&ResponseId::from_string(resp1_id.clone()))
+        .await
+        .unwrap()
+        .expect("first response missing");
+    assert_eq!(stored1.input, "Say hi");
+    assert_eq!(stored1.output, "mock_output_1");
+    assert!(stored1.previous_response_id.is_none());
+
+    let stored2 = storage
+        .get_response(&ResponseId::from_string(resp2_id.to_string()))
+        .await
+        .unwrap()
+        .expect("second response missing");
+    assert_eq!(stored2.previous_response_id.unwrap().0, resp1_id);
+    assert_eq!(stored2.output, "mock_output_2");
+
+    let get1 = router
+        .get_response(None, &stored1.id.0, &ResponsesGetParams::default())
+        .await;
+    assert_eq!(get1.status(), StatusCode::OK);
+    let get1_body_bytes = axum::body::to_bytes(get1.into_body(), usize::MAX)
+        .await
+        .unwrap();
+    let get1_json: serde_json::Value = serde_json::from_slice(&get1_body_bytes).unwrap();
+    assert_eq!(get1_json, body1);
+
+    let get2 = router
+        .get_response(None, &stored2.id.0, &ResponsesGetParams::default())
+        .await;
+    assert_eq!(get2.status(), StatusCode::OK);
+    let get2_body_bytes = axum::body::to_bytes(get2.into_body(), usize::MAX)
+        .await
+        .unwrap();
+    let get2_json: serde_json::Value = serde_json::from_slice(&get2_body_bytes).unwrap();
+    assert_eq!(get2_json, body2);
+
+    server.abort();
+}
+
 /// Test router factory with OpenAI routing mode
 #[tokio::test]
 async fn test_router_factory_openai_mode() {
@@ -179,9 +334,13 @@ async fn test_router_factory_openai_mode() {
 /// Test that unsupported endpoints return proper error codes
 #[tokio::test]
 async fn test_unsupported_endpoints() {
-    let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
-        .await
-        .unwrap();
+    let router = OpenAIRouter::new(
+        "https://api.openai.com".to_string(),
+        None,
+        Arc::new(MemoryResponseStorage::new()),
+    )
+    .await
+    .unwrap();
 
     // Test generate endpoint (SGLang-specific, should not be supported)
     let generate_request = GenerateRequest {
@@ -219,7 +378,9 @@ async fn test_openai_router_chat_completion_with_mock() {
     let base_url = mock_server.base_url();
 
     // Create router pointing to mock server
-    let router = OpenAIRouter::new(base_url, None).await.unwrap();
+    let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
+        .await
+        .unwrap();
 
     // Create a minimal chat completion request
     let mut chat_request = create_minimal_chat_request();
@@ -255,7 +416,9 @@ async fn test_openai_e2e_with_server() {
     let base_url = mock_server.base_url();
 
     // Create router
-    let router = OpenAIRouter::new(base_url, None).await.unwrap();
+    let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
+        .await
+        .unwrap();
 
     // Create Axum app with chat completions endpoint
     let app = Router::new().route(
@@ -319,7 +482,9 @@ async fn test_openai_e2e_with_server() {
 async fn test_openai_router_chat_streaming_with_mock() {
     let mock_server = MockOpenAIServer::new().await;
     let base_url = mock_server.base_url();
-    let router = OpenAIRouter::new(base_url, None).await.unwrap();
+    let router = OpenAIRouter::new(base_url, None, Arc::new(MemoryResponseStorage::new()))
+        .await
+        .unwrap();
 
     // Build a streaming chat request
     let val = json!({
@@ -368,6 +533,7 @@ async fn test_openai_router_circuit_breaker() {
     let router = OpenAIRouter::new(
         "http://invalid-url-that-will-fail".to_string(),
         Some(cb_config),
+        Arc::new(MemoryResponseStorage::new()),
     )
     .await
     .unwrap();
@@ -391,9 +557,13 @@ async fn test_openai_router_models_auth_forwarding() {
     // Start a mock server that requires Authorization
     let expected_auth = "Bearer test-token".to_string();
     let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
-    let router = OpenAIRouter::new(mock_server.base_url(), None)
-        .await
-        .unwrap();
+    let router = OpenAIRouter::new(
+        mock_server.base_url(),
+        None,
+        Arc::new(MemoryResponseStorage::new()),
+    )
+    .await
+    .unwrap();
 
     // 1) Without auth header -> expect 401
     let req = Request::builder()