[router] openai router: support grok model (#11511)

This commit is contained in:
Keyang Ru
2025-10-12 19:44:43 -07:00
committed by GitHub
parent a20e7df8d0
commit 63e84352b7
7 changed files with 248 additions and 184 deletions

View File

@@ -89,7 +89,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
// Build a simple ResponsesRequest that will trigger the tool call
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("search something".to_string()),
instructions: Some("Be brief".to_string()),
@@ -97,15 +97,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
max_tool_calls: None,
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(false),
temperature: Some(0.2),
tool_choice: ToolChoice::default(),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::default()),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
authorization: None,
@@ -113,15 +113,15 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
server_description: None,
require_approval: None,
allowed_tools: None,
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: None,
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_test_mcp_e2e".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: -1,
min_p: 0.0,
@@ -338,7 +338,7 @@ async fn test_conversations_crud_basic() {
#[test]
fn test_responses_request_creation() {
let request = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("Hello, world!".to_string()),
instructions: Some("Be helpful".to_string()),
@@ -346,29 +346,29 @@ fn test_responses_request_creation() {
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::Medium),
summary: None,
}),
service_tier: ServiceTier::Auto,
store: true,
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(false),
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::WebSearchPreview,
..Default::default()
}],
top_logprobs: 5,
}]),
top_logprobs: Some(5),
top_p: Some(0.9),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: Some("test-user".to_string()),
request_id: "resp_test123".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: -1,
min_p: 0.0,
@@ -385,7 +385,7 @@ fn test_responses_request_creation() {
#[test]
fn test_sampling_params_conversion() {
let request = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("Test".to_string()),
instructions: None,
@@ -393,23 +393,23 @@ fn test_sampling_params_conversion() {
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
parallel_tool_calls: true, // Use default true
parallel_tool_calls: Some(true), // Use default true
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true, // Use default true
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(true), // Use default true
stream: Some(false),
temperature: Some(0.8),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![],
top_logprobs: 0, // Use default 0
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![]),
top_logprobs: Some(0), // Use default 0
top_p: Some(0.95),
truncation: Truncation::Auto,
truncation: Some(Truncation::Auto),
user: None,
request_id: "resp_test456".to_string(),
priority: 0,
frequency_penalty: 0.1,
presence_penalty: 0.2,
frequency_penalty: Some(0.1),
presence_penalty: Some(0.2),
stop: None,
top_k: 10,
min_p: 0.05,
@@ -493,7 +493,7 @@ fn test_reasoning_param_default() {
#[test]
fn test_json_serialization() {
let request = ResponsesRequest {
background: true,
background: Some(true),
include: None,
input: ResponseInput::Text("Test input".to_string()),
instructions: Some("Test instructions".to_string()),
@@ -501,29 +501,29 @@ fn test_json_serialization() {
max_tool_calls: Some(5),
metadata: None,
model: Some("gpt-4".to_string()),
parallel_tool_calls: false,
parallel_tool_calls: Some(false),
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::High),
summary: None,
}),
service_tier: ServiceTier::Priority,
store: false,
stream: true,
service_tier: Some(ServiceTier::Priority),
store: Some(false),
stream: Some(true),
temperature: Some(0.9),
tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Required)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::CodeInterpreter,
..Default::default()
}],
top_logprobs: 10,
}]),
top_logprobs: Some(10),
top_p: Some(0.8),
truncation: Truncation::Auto,
truncation: Some(Truncation::Auto),
user: Some("test_user".to_string()),
request_id: "resp_comprehensive_test".to_string(),
priority: 1,
frequency_penalty: 0.3,
presence_penalty: 0.4,
frequency_penalty: Some(0.3),
presence_penalty: Some(0.4),
stop: None,
top_k: 50,
min_p: 0.1,
@@ -537,9 +537,9 @@ fn test_json_serialization() {
assert_eq!(parsed.request_id, "resp_comprehensive_test");
assert_eq!(parsed.model, Some("gpt-4".to_string()));
assert!(parsed.background);
assert!(parsed.stream);
assert_eq!(parsed.tools.len(), 1);
assert_eq!(parsed.background, Some(true));
assert_eq!(parsed.stream, Some(true));
assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1));
}
#[tokio::test]
@@ -620,7 +620,7 @@ async fn test_multi_turn_loop_with_mcp() {
// Build request with MCP tools
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("search for SGLang".to_string()),
instructions: Some("Be helpful".to_string()),
@@ -628,30 +628,30 @@ async fn test_multi_turn_loop_with_mcp() {
max_tool_calls: None, // No limit - test unlimited
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(false),
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
server_description: Some("Mock MCP server for testing".to_string()),
require_approval: Some("never".to_string()),
..Default::default()
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: Some(1.0),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_multi_turn_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: 50,
min_p: 0.0,
@@ -796,7 +796,7 @@ async fn test_max_tool_calls_limit() {
.expect("router");
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("test max calls".to_string()),
instructions: None,
@@ -804,28 +804,28 @@ async fn test_max_tool_calls_limit() {
max_tool_calls: Some(1), // Limit to 1 call
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: false,
stream: false,
service_tier: Some(ServiceTier::Auto),
store: Some(false),
stream: Some(false),
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
..Default::default()
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: Some(1.0),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_max_calls_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: 50,
min_p: 0.0,
@@ -990,7 +990,7 @@ async fn test_streaming_with_mcp_tool_calls() {
// Build streaming request with MCP tools
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("search for something interesting".to_string()),
instructions: Some("Use tools when needed".to_string()),
@@ -998,30 +998,30 @@ async fn test_streaming_with_mcp_tool_calls() {
max_tool_calls: Some(3),
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: true, // KEY: Enable streaming
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(true), // KEY: Enable streaming
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
server_description: Some("Mock MCP for streaming test".to_string()),
require_approval: Some("never".to_string()),
..Default::default()
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: Some(1.0),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_streaming_mcp_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: 50,
min_p: 0.0,
@@ -1271,7 +1271,7 @@ async fn test_streaming_multi_turn_with_mcp() {
let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await;
let req = ResponsesRequest {
background: false,
background: Some(false),
include: None,
input: ResponseInput::Text("complex query requiring multiple tool calls".to_string()),
instructions: Some("Be thorough".to_string()),
@@ -1279,28 +1279,28 @@ async fn test_streaming_multi_turn_with_mcp() {
max_tool_calls: Some(5), // Allow multiple rounds
metadata: None,
model: Some("mock-model".to_string()),
parallel_tool_calls: true,
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true,
stream: true,
service_tier: Some(ServiceTier::Auto),
store: Some(true),
stream: Some(true),
temperature: Some(0.8),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![ResponseTool {
r#type: ResponseToolType::Mcp,
server_url: Some(mcp.url()),
server_label: Some("mock".to_string()),
..Default::default()
}],
top_logprobs: 0,
}]),
top_logprobs: Some(0),
top_p: Some(1.0),
truncation: Truncation::Disabled,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_streaming_multiturn_test".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
stop: None,
top_k: 50,
min_p: 0.0,