[router] openai router: support grok model (#11511)
@@ -1073,8 +1073,8 @@ fn generate_request_id() -> String {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ResponsesRequest {
     /// Run the request in the background
-    #[serde(default)]
-    pub background: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub background: Option<bool>,
 
     /// Fields to include in the response
     #[serde(skip_serializing_if = "Option::is_none")]
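The pattern above repeats for every field this commit touches: a concrete field with a serde default becomes an `Option<T>` tagged `skip_serializing_if = "Option::is_none"`, so an unset field is omitted from the upstream JSON instead of being sent with a synthesized default. A minimal sketch of that serde behavior (the `Probe` struct is illustrative, not from the router):

    use serde::Serialize;

    #[derive(Serialize)]
    struct Probe {
        #[serde(skip_serializing_if = "Option::is_none")]
        background: Option<bool>,
    }

    fn main() {
        // None disappears from the payload entirely, so a strict backend
        // such as xAI never sees a field the caller did not set.
        assert_eq!(serde_json::to_string(&Probe { background: None }).unwrap(), "{}");
        assert_eq!(
            serde_json::to_string(&Probe { background: Some(true) }).unwrap(),
            r#"{"background":true}"#
        );
    }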
@@ -1108,8 +1108,8 @@ pub struct ResponsesRequest {
     pub conversation: Option<String>,
 
     /// Whether to enable parallel tool calls
-    #[serde(default = "default_true")]
-    pub parallel_tool_calls: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
 
     /// ID of previous response to continue from
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1120,40 +1120,40 @@ pub struct ResponsesRequest {
     pub reasoning: Option<ResponseReasoningParam>,
 
     /// Service tier
-    #[serde(default)]
-    pub service_tier: ServiceTier,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option<ServiceTier>,
 
     /// Whether to store the response
-    #[serde(default = "default_true")]
-    pub store: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub store: Option<bool>,
 
     /// Whether to stream the response
-    #[serde(default)]
-    pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
 
     /// Temperature for sampling
     #[serde(skip_serializing_if = "Option::is_none")]
     pub temperature: Option<f32>,
 
     /// Tool choice behavior
-    #[serde(default)]
-    pub tool_choice: ToolChoice,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
 
     /// Available tools
-    #[serde(default)]
-    pub tools: Vec<ResponseTool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<ResponseTool>>,
 
     /// Number of top logprobs to return
-    #[serde(default)]
-    pub top_logprobs: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option<u32>,
 
     /// Top-p sampling parameter
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_p: Option<f32>,
 
     /// Truncation behavior
-    #[serde(default)]
-    pub truncation: Truncation,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub truncation: Option<Truncation>,
 
     /// User identifier
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1168,12 +1168,12 @@ pub struct ResponsesRequest {
     pub priority: i32,
 
     /// Frequency penalty
-    #[serde(default)]
-    pub frequency_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
 
     /// Presence penalty
-    #[serde(default)]
-    pub presence_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
 
     /// Stop sequences
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1210,7 +1210,7 @@ fn default_repetition_penalty() -> f32 {
 impl Default for ResponsesRequest {
     fn default() -> Self {
         Self {
-            background: false,
+            background: None,
             include: None,
             input: ResponseInput::Text(String::new()),
             instructions: None,
@@ -1219,23 +1219,23 @@ impl Default for ResponsesRequest {
             metadata: None,
             model: None,
             conversation: None,
-            parallel_tool_calls: true,
+            parallel_tool_calls: None,
             previous_response_id: None,
             reasoning: None,
-            service_tier: ServiceTier::default(),
-            store: true,
-            stream: false,
+            service_tier: None,
+            store: None,
+            stream: None,
             temperature: None,
-            tool_choice: ToolChoice::default(),
-            tools: Vec::new(),
-            top_logprobs: 0,
+            tool_choice: None,
+            tools: None,
+            top_logprobs: None,
             top_p: None,
-            truncation: Truncation::default(),
+            truncation: None,
             user: None,
             request_id: generate_request_id(),
             priority: 0,
-            frequency_penalty: 0.0,
-            presence_penalty: 0.0,
+            frequency_penalty: None,
+            presence_penalty: None,
            stop: None,
            top_k: default_top_k(),
            min_p: 0.0,
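With every optional field now defaulting to `None`, callers can spell out only what they need and take the rest from `Default`, as the mock-router tests near the end of this diff do. A sketch (model name illustrative):

    let request = ResponsesRequest {
        model: Some("grok-4".to_string()),
        input: ResponseInput::Text("Say hi".to_string()),
        stream: Some(true),
        ..Default::default()
    };
    // Serialization omits every field left as None; the only always-present
    // extras are the non-Option SGLang fields (request_id, priority, top_k,
    // min_p, ...), which the router strips before forwarding upstream.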
@@ -1299,14 +1299,18 @@ impl ResponsesRequest {
                 "top_p".to_string(),
                 Value::Number(Number::from_f64(top_p as f64).unwrap()),
             );
-        params.insert(
-            "frequency_penalty".to_string(),
-            Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()),
-        );
-        params.insert(
-            "presence_penalty".to_string(),
-            Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()),
-        );
+        if let Some(fp) = self.frequency_penalty {
+            params.insert(
+                "frequency_penalty".to_string(),
+                Value::Number(Number::from_f64(fp as f64).unwrap()),
+            );
+        }
+        if let Some(pp) = self.presence_penalty {
+            params.insert(
+                "presence_penalty".to_string(),
+                Value::Number(Number::from_f64(pp as f64).unwrap()),
+            );
+        }
         params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
         params.insert(
             "min_p".to_string(),
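The effect of the two new `if let` guards: a penalty key only enters the sampling-params map when the caller actually set one, instead of always being emitted as 0.0. A self-contained sketch of the guard (local names are illustrative):

    use serde_json::{Map, Number, Value};

    fn main() {
        let mut params: Map<String, Value> = Map::new();
        let frequency_penalty: Option<f32> = None; // as when the client omits it
        if let Some(fp) = frequency_penalty {
            params.insert(
                "frequency_penalty".to_string(),
                Value::Number(Number::from_f64(fp as f64).unwrap()),
            );
        }
        // The key is simply absent, so the backend's own default applies.
        assert!(params.get("frequency_penalty").is_none());
    }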
@@ -1337,7 +1341,7 @@ impl ResponsesRequest {
 
 impl GenerationRequest for ResponsesRequest {
     fn is_stream(&self) -> bool {
-        self.stream
+        self.stream.unwrap_or(false)
     }
 
     fn get_model(&self) -> Option<&str> {
@@ -1523,13 +1527,13 @@ impl ResponsesResponse {
             max_output_tokens: request.max_output_tokens,
             model: model_name,
             output,
-            parallel_tool_calls: request.parallel_tool_calls,
+            parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
             previous_response_id: request.previous_response_id.clone(),
             reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo {
                 effort: r.effort.as_ref().map(|e| format!("{:?}", e)),
                 summary: None,
             }),
-            store: request.store,
+            store: request.store.unwrap_or(false),
             temperature: request.temperature,
             text: Some(ResponseTextFormat {
                 format: TextFormatType {
@@ -1537,17 +1541,19 @@ impl ResponsesResponse {
                 },
             }),
             tool_choice: match &request.tool_choice {
-                ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
-                ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
-                ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
-                ToolChoice::Function { .. } => "function".to_string(),
-                ToolChoice::AllowedTools { mode, .. } => mode.clone(),
+                Some(ToolChoice::Value(ToolChoiceValue::Auto)) => "auto".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::Required)) => "required".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::None)) => "none".to_string(),
+                Some(ToolChoice::Function { .. }) => "function".to_string(),
+                Some(ToolChoice::AllowedTools { mode, .. }) => mode.clone(),
+                None => "auto".to_string(),
             },
-            tools: request.tools.clone(),
+            tools: request.tools.clone().unwrap_or_default(),
             top_p: request.top_p,
             truncation: match &request.truncation {
-                Truncation::Auto => Some("auto".to_string()),
-                Truncation::Disabled => Some("disabled".to_string()),
+                Some(Truncation::Auto) => Some("auto".to_string()),
+                Some(Truncation::Disabled) => Some("disabled".to_string()),
+                None => None,
             },
             usage: usage.map(ResponsesUsage::Classic),
             user: request.user.clone(),
@@ -689,9 +689,13 @@ pub(super) async fn execute_tool_loop(
     if state.total_calls > 0 {
         let server_label = original_body
             .tools
-            .iter()
-            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-            .and_then(|t| t.server_label.as_deref())
+            .as_ref()
+            .and_then(|tools| {
+                tools
+                    .iter()
+                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                    .and_then(|t| t.server_label.as_deref())
+            })
             .unwrap_or("mcp");
 
         // Build mcp_list_tools item
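The same label lookup recurs in `build_incomplete_response` and in the streaming tool-interception handler later in this diff. A hypothetical helper, not part of this commit, would express it once (sketch only; assumes `ResponseTool`'s `r#type` and `server_label` fields as used above):

    fn mcp_server_label(tools: Option<&[ResponseTool]>) -> &str {
        tools
            .and_then(|ts| {
                ts.iter()
                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
                    .and_then(|t| t.server_label.as_deref())
            })
            .unwrap_or("mcp")
    }

Call sites would then read `mcp_server_label(original_body.tools.as_deref())`.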
@@ -747,9 +751,13 @@ pub(super) fn build_incomplete_response(
     if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
         let server_label = original_body
             .tools
-            .iter()
-            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-            .and_then(|t| t.server_label.as_deref())
+            .as_ref()
+            .and_then(|tools| {
+                tools
+                    .iter()
+                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                    .and_then(|t| t.server_label.as_deref())
+            })
             .unwrap_or("mcp");
 
         // Find any function_call items and convert them to mcp_call (incomplete)
@@ -129,7 +129,10 @@ pub(super) fn patch_streaming_response_json(
         }
     }
 
-    obj.insert("store".to_string(), Value::Bool(original_body.store));
+    obj.insert(
+        "store".to_string(),
+        Value::Bool(original_body.store.unwrap_or(false)),
+    );
 
     if obj
         .get("model")
@@ -205,7 +208,7 @@ pub(super) fn rewrite_streaming_block(
 
     let mut changed = false;
     if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) {
-        let desired_store = Value::Bool(original_body.store);
+        let desired_store = Value::Bool(original_body.store.unwrap_or(false));
         if response_obj.get("store") != Some(&desired_store) {
             response_obj.insert("store".to_string(), desired_store);
             changed = true;
@@ -267,10 +270,11 @@ pub(super) fn rewrite_streaming_block(
 
 /// Mask function tools as MCP tools in response for client
 pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
-    let mcp_tool = original_body
-        .tools
-        .iter()
-        .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some());
+    let mcp_tool = original_body.tools.as_ref().and_then(|tools| {
+        tools
+            .iter()
+            .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())
+    });
     let Some(t) = mcp_tool else {
         return;
     };
@@ -148,7 +148,11 @@ impl OpenAIRouter {
         original_previous_response_id: Option<String>,
     ) -> Response {
         // Check if MCP is active for this request
-        let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
+        let req_mcp_manager = if let Some(ref tools) = original_body.tools {
+            mcp_manager_from_request_tools(tools.as_slice()).await
+        } else {
+            None
+        };
         let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
 
         let mut response_json: Value;
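Because `mcp_manager_from_request_tools` is async, the usual `Option::map` shortcut cannot await inside its closure, hence the `if let`; the same construct appears again in `handle_streaming_response` below. An equivalent `match` form, shown only as a sketch:

    let req_mcp_manager = match original_body.tools.as_deref() {
        Some(tools) => mcp_manager_from_request_tools(tools).await,
        None => None,
    };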
@@ -183,6 +187,7 @@ impl OpenAIRouter {
             }
         } else {
             // No MCP - simple request
+
             let mut request_builder = self.client.post(&url).json(&payload);
             if let Some(h) = headers {
                 request_builder = apply_request_headers(h, request_builder, true);
@@ -385,6 +390,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             }
         };
         if let Some(obj) = payload.as_object_mut() {
+            // Always remove SGLang-specific fields (unsupported by OpenAI)
             for key in [
                 "top_k",
                 "min_p",
@@ -535,7 +541,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
                 .into_response();
         }
 
-        // Clone the body and override model if needed
+        // Clone the body for validation and logic, but we'll build payload differently
         let mut request_body = body.clone();
         if let Some(model) = model_id {
             request_body.model = Some(model.to_string());
@@ -690,7 +696,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         }
 
         // Always set store=false for upstream (we store internally)
-        request_body.store = false;
+        request_body.store = Some(false);
 
         // Convert to JSON and strip SGLang-specific fields
         let mut payload = match to_value(&request_body) {
@@ -704,14 +710,13 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             }
         };
 
-        // Remove SGLang-specific fields
+        // Remove SGLang-specific fields only
         if let Some(obj) = payload.as_object_mut() {
+            // Remove SGLang-specific fields (not part of OpenAI API)
             for key in [
                 "request_id",
                 "priority",
                 "top_k",
-                "frequency_penalty",
-                "presence_penalty",
                 "min_p",
                 "min_tokens",
                 "regex",
@@ -732,10 +737,38 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         ] {
             obj.remove(key);
         }
+        // xAI doesn't support the OpenAI item-type input: https://platform.openai.com/docs/api-reference/responses/create#responses-create-input-input-item-list-item
+        // To achieve xAI compatibility, strip extra fields from input messages (id, status).
+        // xAI also doesn't accept output_text as the content type for assistant-role items,
+        // so normalize content types: output_text -> input_text
+        if let Some(input_arr) = obj.get_mut("input").and_then(Value::as_array_mut) {
+            for item_obj in input_arr.iter_mut().filter_map(Value::as_object_mut) {
+                // Remove fields not universally supported
+                item_obj.remove("id");
+                item_obj.remove("status");
+
+                // Normalize content types to input_text (xAI compatibility)
+                if let Some(content_arr) =
+                    item_obj.get_mut("content").and_then(Value::as_array_mut)
+                {
+                    for content_obj in content_arr.iter_mut().filter_map(Value::as_object_mut) {
+                        // Change output_text to input_text
+                        if content_obj.get("type").and_then(Value::as_str)
+                            == Some("output_text")
+                        {
+                            content_obj.insert(
+                                "type".to_string(),
+                                Value::String("input_text".to_string()),
+                            );
+                        }
+                    }
+                }
+            }
+        }
     }
 
     // Delegate to streaming or non-streaming handler
-    if body.stream {
+    if body.stream.unwrap_or(false) {
         handle_streaming_response(
             &self.client,
             &self.circuit_breaker,
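To make the normalization above concrete, here is a sketch of one conversation-history item before and after the loop runs (field values illustrative):

    use serde_json::json;

    // As echoed back from a previous /v1/responses turn:
    let before = json!({
        "type": "message",
        "id": "msg_123",
        "status": "completed",
        "role": "assistant",
        "content": [{ "type": "output_text", "text": "Hi there" }]
    });
    // After stripping id/status and rewriting the content type,
    // the item forwarded to xAI would read:
    let after = json!({
        "type": "message",
        "role": "assistant",
        "content": [{ "type": "input_text", "text": "Hi there" }]
    });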
@@ -572,7 +572,7 @@ pub(super) fn apply_event_transformations_inplace(
         .get_mut("response")
         .and_then(|v| v.as_object_mut())
     {
-        let desired_store = Value::Bool(original_request.store);
+        let desired_store = Value::Bool(original_request.store.unwrap_or(false));
         if response_obj.get("store") != Some(&desired_store) {
             response_obj.insert("store".to_string(), desired_store);
             changed = true;
@@ -597,8 +597,13 @@ pub(super) fn apply_event_transformations_inplace(
         if response_obj.get("tools").is_some() {
             let requested_mcp = original_request
                 .tools
-                .iter()
-                .any(|t| matches!(t.r#type, ResponseToolType::Mcp));
+                .as_ref()
+                .map(|tools| {
+                    tools
+                        .iter()
+                        .any(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                })
+                .unwrap_or(false);
 
             if requested_mcp {
                 if let Some(mcp_tools) = build_mcp_tools_value(original_request) {
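The `.as_ref().map(...).unwrap_or(false)` chain could also be written with `Option::is_some_and` (stable since Rust 1.70); behavior is identical to the diff's form, shown only as a sketch:

    let requested_mcp = original_request
        .tools
        .as_ref()
        .is_some_and(|tools| tools.iter().any(|t| matches!(t.r#type, ResponseToolType::Mcp)));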
@@ -658,8 +663,8 @@ pub(super) fn apply_event_transformations_inplace(
 
 /// Helper to build MCP tools value
 fn build_mcp_tools_value(original_body: &ResponsesRequest) -> Option<Value> {
-    let mcp_tool = original_body
-        .tools
+    let tools = original_body.tools.as_ref()?;
+    let mcp_tool = tools
         .iter()
         .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
 
@@ -1000,7 +1005,7 @@ pub(super) async fn handle_simple_streaming_passthrough(
 
     let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
 
-    let should_store = original_body.store;
+    let should_store = original_body.store.unwrap_or(false);
     let original_request = original_body.clone();
     let persist_needed = original_request.conversation.is_some();
     let previous_response_id = original_previous_response_id.clone();
@@ -1134,7 +1139,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
     prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
 
     let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
-    let should_store = original_body.store;
+    let should_store = original_body.store.unwrap_or(false);
     let original_request = original_body.clone();
     let persist_needed = original_request.conversation.is_some();
     let previous_response_id = original_previous_response_id.clone();
@@ -1161,9 +1166,13 @@ pub(super) async fn handle_streaming_with_tool_interception(
 
         let server_label = original_request
             .tools
-            .iter()
-            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-            .and_then(|t| t.server_label.as_deref())
+            .as_ref()
+            .and_then(|tools| {
+                tools
+                    .iter()
+                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                    .and_then(|t| t.server_label.as_deref())
+            })
             .unwrap_or("mcp");
 
         loop {
@@ -1488,7 +1497,11 @@ pub(super) async fn handle_streaming_response(
     original_previous_response_id: Option<String>,
 ) -> Response {
     // Check if MCP is active for this request
-    let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
+    let req_mcp_manager = if let Some(ref tools) = original_body.tools {
+        mcp_manager_from_request_tools(tools.as_slice()).await
+    } else {
+        None
+    };
     let active_mcp = req_mcp_manager.as_ref().or(mcp_manager);
 
     // If no MCP is active, use simple pass-through streaming
@@ -89,7 +89,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
 
     // Build a simple ResponsesRequest that will trigger the tool call
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("search something".to_string()),
         instructions: Some("Be brief".to_string()),
@@ -97,15 +97,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
         max_tool_calls: None,
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(false),
         temperature: Some(0.2),
-        tool_choice: ToolChoice::default(),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::default()),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             authorization: None,
@@ -113,15 +113,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
             server_description: None,
             require_approval: None,
             allowed_tools: None,
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: None,
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_test_mcp_e2e".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: -1,
         min_p: 0.0,
@@ -338,7 +338,7 @@ async fn test_conversations_crud_basic() {
 #[test]
 fn test_responses_request_creation() {
     let request = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("Hello, world!".to_string()),
         instructions: Some("Be helpful".to_string()),
@@ -346,29 +346,29 @@ fn test_responses_request_creation() {
         max_tool_calls: None,
         metadata: None,
         model: Some("test-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
             effort: Some(ReasoningEffort::Medium),
             summary: None,
         }),
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(false),
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::WebSearchPreview,
             ..Default::default()
-        }],
-        top_logprobs: 5,
+        }]),
+        top_logprobs: Some(5),
         top_p: Some(0.9),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: Some("test-user".to_string()),
         request_id: "resp_test123".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: -1,
         min_p: 0.0,
@@ -385,7 +385,7 @@ fn test_responses_request_creation() {
 #[test]
 fn test_sampling_params_conversion() {
     let request = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("Test".to_string()),
         instructions: None,
@@ -393,23 +393,23 @@ fn test_sampling_params_conversion() {
         max_tool_calls: None,
         metadata: None,
         model: Some("test-model".to_string()),
-        parallel_tool_calls: true, // Use default true
+        parallel_tool_calls: Some(true), // Use default true
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true, // Use default true
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true), // Use default true
+        stream: Some(false),
         temperature: Some(0.8),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![],
-        top_logprobs: 0, // Use default 0
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![]),
+        top_logprobs: Some(0), // Use default 0
         top_p: Some(0.95),
-        truncation: Truncation::Auto,
+        truncation: Some(Truncation::Auto),
         user: None,
         request_id: "resp_test456".to_string(),
         priority: 0,
-        frequency_penalty: 0.1,
-        presence_penalty: 0.2,
+        frequency_penalty: Some(0.1),
+        presence_penalty: Some(0.2),
         stop: None,
         top_k: 10,
         min_p: 0.05,
@@ -493,7 +493,7 @@ fn test_reasoning_param_default() {
 #[test]
 fn test_json_serialization() {
     let request = ResponsesRequest {
-        background: true,
+        background: Some(true),
         include: None,
         input: ResponseInput::Text("Test input".to_string()),
         instructions: Some("Test instructions".to_string()),
@@ -501,29 +501,29 @@ fn test_json_serialization() {
         max_tool_calls: Some(5),
         metadata: None,
         model: Some("gpt-4".to_string()),
-        parallel_tool_calls: false,
+        parallel_tool_calls: Some(false),
         previous_response_id: None,
         reasoning: Some(ResponseReasoningParam {
             effort: Some(ReasoningEffort::High),
             summary: None,
         }),
-        service_tier: ServiceTier::Priority,
-        store: false,
-        stream: true,
+        service_tier: Some(ServiceTier::Priority),
+        store: Some(false),
+        stream: Some(true),
         temperature: Some(0.9),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Required)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::CodeInterpreter,
             ..Default::default()
-        }],
-        top_logprobs: 10,
+        }]),
+        top_logprobs: Some(10),
         top_p: Some(0.8),
-        truncation: Truncation::Auto,
+        truncation: Some(Truncation::Auto),
         user: Some("test_user".to_string()),
         request_id: "resp_comprehensive_test".to_string(),
         priority: 1,
-        frequency_penalty: 0.3,
-        presence_penalty: 0.4,
+        frequency_penalty: Some(0.3),
+        presence_penalty: Some(0.4),
         stop: None,
         top_k: 50,
         min_p: 0.1,
@@ -537,9 +537,9 @@ fn test_json_serialization() {
 
     assert_eq!(parsed.request_id, "resp_comprehensive_test");
     assert_eq!(parsed.model, Some("gpt-4".to_string()));
-    assert!(parsed.background);
-    assert!(parsed.stream);
-    assert_eq!(parsed.tools.len(), 1);
+    assert_eq!(parsed.background, Some(true));
+    assert_eq!(parsed.stream, Some(true));
+    assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
 }
 
 #[tokio::test]
@@ -620,7 +620,7 @@ async fn test_multi_turn_loop_with_mcp() {
 
     // Build request with MCP tools
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("search for SGLang".to_string()),
         instructions: Some("Be helpful".to_string()),
@@ -628,30 +628,30 @@ async fn test_multi_turn_loop_with_mcp() {
         max_tool_calls: None, // No limit - test unlimited
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(false),
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             server_description: Some("Mock MCP server for testing".to_string()),
             require_approval: Some("never".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_multi_turn_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -796,7 +796,7 @@ async fn test_max_tool_calls_limit() {
         .expect("router");
 
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("test max calls".to_string()),
         instructions: None,
@@ -804,28 +804,28 @@ async fn test_max_tool_calls_limit() {
         max_tool_calls: Some(1), // Limit to 1 call
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: false,
-        stream: false,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(false),
+        stream: Some(false),
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_max_calls_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -990,7 +990,7 @@ async fn test_streaming_with_mcp_tool_calls() {
 
     // Build streaming request with MCP tools
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("search for something interesting".to_string()),
         instructions: Some("Use tools when needed".to_string()),
@@ -998,30 +998,30 @@ async fn test_streaming_with_mcp_tool_calls() {
         max_tool_calls: Some(3),
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: true, // KEY: Enable streaming
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(true), // KEY: Enable streaming
         temperature: Some(0.7),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             server_description: Some("Mock MCP for streaming test".to_string()),
             require_approval: Some("never".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_streaming_mcp_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -1271,7 +1271,7 @@ async fn test_streaming_multi_turn_with_mcp() {
     let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await;
 
     let req = ResponsesRequest {
-        background: false,
+        background: Some(false),
         include: None,
         input: ResponseInput::Text("complex query requiring multiple tool calls".to_string()),
         instructions: Some("Be thorough".to_string()),
@@ -1279,28 +1279,28 @@ async fn test_streaming_multi_turn_with_mcp() {
         max_tool_calls: Some(5), // Allow multiple rounds
         metadata: None,
         model: Some("mock-model".to_string()),
-        parallel_tool_calls: true,
+        parallel_tool_calls: Some(true),
         previous_response_id: None,
         reasoning: None,
-        service_tier: ServiceTier::Auto,
-        store: true,
-        stream: true,
+        service_tier: Some(ServiceTier::Auto),
+        store: Some(true),
+        stream: Some(true),
         temperature: Some(0.8),
-        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
-        tools: vec![ResponseTool {
+        tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
+        tools: Some(vec![ResponseTool {
             r#type: ResponseToolType::Mcp,
             server_url: Some(mcp.url()),
             server_label: Some("mock".to_string()),
             ..Default::default()
-        }],
-        top_logprobs: 0,
+        }]),
+        top_logprobs: Some(0),
         top_p: Some(1.0),
-        truncation: Truncation::Disabled,
+        truncation: Some(Truncation::Disabled),
         user: None,
         request_id: "resp_streaming_multiturn_test".to_string(),
         priority: 0,
-        frequency_penalty: 0.0,
-        presence_penalty: 0.0,
+        frequency_penalty: Some(0.0),
+        presence_penalty: Some(0.0),
         stop: None,
         top_k: 50,
         min_p: 0.0,
@@ -234,7 +234,7 @@ async fn test_openai_router_responses_with_mock() {
     let request1 = ResponsesRequest {
         model: Some("gpt-4o-mini".to_string()),
         input: ResponseInput::Text("Say hi".to_string()),
-        store: true,
+        store: Some(true),
         ..Default::default()
     };
 
@@ -250,7 +250,7 @@ async fn test_openai_router_responses_with_mock() {
     let request2 = ResponsesRequest {
         model: Some("gpt-4o-mini".to_string()),
         input: ResponseInput::Text("Thanks".to_string()),
-        store: true,
+        store: Some(true),
         previous_response_id: Some(resp1_id.clone()),
         ..Default::default()
     };
@@ -501,8 +501,8 @@ async fn test_openai_router_responses_streaming_with_mock() {
         instructions: Some("Be kind".to_string()),
         metadata: Some(metadata),
         previous_response_id: Some("resp_prev_chain".to_string()),
-        store: true,
-        stream: true,
+        store: Some(true),
+        stream: Some(true),
         ..Default::default()
     };