[router] openai router: support grok model (#11511)

This commit is contained in:
Keyang Ru
2025-10-12 19:44:43 -07:00
committed by GitHub
parent a20e7df8d0
commit 63e84352b7
7 changed files with 248 additions and 184 deletions

View File

@@ -689,9 +689,13 @@ pub(super) async fn execute_tool_loop(
if state.total_calls > 0 {
let server_label = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
})
.unwrap_or("mcp");
// Build mcp_list_tools item
@@ -747,9 +751,13 @@ pub(super) fn build_incomplete_response(
if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
let server_label = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
})
.unwrap_or("mcp");
// Find any function_call items and convert them to mcp_call (incomplete)

View File

@@ -129,7 +129,10 @@ pub(super) fn patch_streaming_response_json(
}
}
obj.insert("store".to_string(), Value::Bool(original_body.store));
obj.insert(
"store".to_string(),
Value::Bool(original_body.store.unwrap_or(false)),
);
if obj
.get("model")
@@ -205,7 +208,7 @@ pub(super) fn rewrite_streaming_block(
let mut changed = false;
if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) {
let desired_store = Value::Bool(original_body.store);
let desired_store = Value::Bool(original_body.store.unwrap_or(false));
if response_obj.get("store") != Some(&desired_store) {
response_obj.insert("store".to_string(), desired_store);
changed = true;
@@ -267,10 +270,11 @@ pub(super) fn rewrite_streaming_block(
/// Mask function tools as MCP tools in response for client
pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
let mcp_tool = original_body
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some());
let mcp_tool = original_body.tools.as_ref().and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())
});
let Some(t) = mcp_tool else {
return;
};

View File

@@ -148,7 +148,11 @@ impl OpenAIRouter {
original_previous_response_id: Option<String>,
) -> Response {
// Check if MCP is active for this request
let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
let req_mcp_manager = if let Some(ref tools) = original_body.tools {
mcp_manager_from_request_tools(tools.as_slice()).await
} else {
None
};
let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
let mut response_json: Value;
@@ -183,6 +187,7 @@ impl OpenAIRouter {
}
} else {
// No MCP - simple request
let mut request_builder = self.client.post(&url).json(&payload);
if let Some(h) = headers {
request_builder = apply_request_headers(h, request_builder, true);
@@ -385,6 +390,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
}
};
if let Some(obj) = payload.as_object_mut() {
// Always remove SGLang-specific fields (unsupported by OpenAI)
for key in [
"top_k",
"min_p",
@@ -535,7 +541,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
.into_response();
}
// Clone the body and override model if needed
// Clone the body for validation and logic, but we'll build payload differently
let mut request_body = body.clone();
if let Some(model) = model_id {
request_body.model = Some(model.to_string());
@@ -690,7 +696,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
}
// Always set store=false for upstream (we store internally)
request_body.store = false;
request_body.store = Some(false);
// Convert to JSON and strip SGLang-specific fields
let mut payload = match to_value(&request_body) {
@@ -704,14 +710,13 @@ impl crate::routers::RouterTrait for OpenAIRouter {
}
};
// Remove SGLang-specific fields
// Remove SGLang-specific fields only
if let Some(obj) = payload.as_object_mut() {
// Remove SGLang-specific fields (not part of OpenAI API)
for key in [
"request_id",
"priority",
"top_k",
"frequency_penalty",
"presence_penalty",
"min_p",
"min_tokens",
"regex",
@@ -732,10 +737,38 @@ impl crate::routers::RouterTrait for OpenAIRouter {
] {
obj.remove(key);
}
// XAI doesn't support the OPENAI item type input: https://platform.openai.com/docs/api-reference/responses/create#responses-create-input-input-item-list-item
// To Achieve XAI compatibility, strip extra fields from input messages (id, status)
// XAI doesn't support output_text as type for content with role of assistant
// so normalize content types: output_text -> input_text
if let Some(input_arr) = obj.get_mut("input").and_then(Value::as_array_mut) {
for item_obj in input_arr.iter_mut().filter_map(Value::as_object_mut) {
// Remove fields not universally supported
item_obj.remove("id");
item_obj.remove("status");
// Normalize content types to input_text (xAI compatibility)
if let Some(content_arr) =
item_obj.get_mut("content").and_then(Value::as_array_mut)
{
for content_obj in content_arr.iter_mut().filter_map(Value::as_object_mut) {
// Change output_text to input_text
if content_obj.get("type").and_then(Value::as_str)
== Some("output_text")
{
content_obj.insert(
"type".to_string(),
Value::String("input_text".to_string()),
);
}
}
}
}
}
}
// Delegate to streaming or non-streaming handler
if body.stream {
if body.stream.unwrap_or(false) {
handle_streaming_response(
&self.client,
&self.circuit_breaker,

View File

@@ -572,7 +572,7 @@ pub(super) fn apply_event_transformations_inplace(
.get_mut("response")
.and_then(|v| v.as_object_mut())
{
let desired_store = Value::Bool(original_request.store);
let desired_store = Value::Bool(original_request.store.unwrap_or(false));
if response_obj.get("store") != Some(&desired_store) {
response_obj.insert("store".to_string(), desired_store);
changed = true;
@@ -597,8 +597,13 @@ pub(super) fn apply_event_transformations_inplace(
if response_obj.get("tools").is_some() {
let requested_mcp = original_request
.tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::Mcp));
.as_ref()
.map(|tools| {
tools
.iter()
.any(|t| matches!(t.r#type, ResponseToolType::Mcp))
})
.unwrap_or(false);
if requested_mcp {
if let Some(mcp_tools) = build_mcp_tools_value(original_request) {
@@ -658,8 +663,8 @@ pub(super) fn apply_event_transformations_inplace(
/// Helper to build MCP tools value
fn build_mcp_tools_value(original_body: &ResponsesRequest) -> Option<Value> {
let mcp_tool = original_body
.tools
let tools = original_body.tools.as_ref()?;
let mcp_tool = tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())?;
@@ -1000,7 +1005,7 @@ pub(super) async fn handle_simple_streaming_passthrough(
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
let should_store = original_body.store;
let should_store = original_body.store.unwrap_or(false);
let original_request = original_body.clone();
let persist_needed = original_request.conversation.is_some();
let previous_response_id = original_previous_response_id.clone();
@@ -1134,7 +1139,7 @@ pub(super) async fn handle_streaming_with_tool_interception(
prepare_mcp_payload_for_streaming(&mut payload, active_mcp);
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
let should_store = original_body.store;
let should_store = original_body.store.unwrap_or(false);
let original_request = original_body.clone();
let persist_needed = original_request.conversation.is_some();
let previous_response_id = original_previous_response_id.clone();
@@ -1161,9 +1166,13 @@ pub(super) async fn handle_streaming_with_tool_interception(
let server_label = original_request
.tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
.as_ref()
.and_then(|tools| {
tools
.iter()
.find(|t| matches!(t.r#type, ResponseToolType::Mcp))
.and_then(|t| t.server_label.as_deref())
})
.unwrap_or("mcp");
loop {
@@ -1488,7 +1497,11 @@ pub(super) async fn handle_streaming_response(
original_previous_response_id: Option<String>,
) -> Response {
// Check if MCP is active for this request
let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
let req_mcp_manager = if let Some(ref tools) = original_body.tools {
mcp_manager_from_request_tools(tools.as_slice()).await
} else {
None
};
let active_mcp = req_mcp_manager.as_ref().or(mcp_manager);
// If no MCP is active, use simple pass-through streaming