[router] openai router: support grok model (#11511)

This commit is contained in:
Keyang Ru
2025-10-12 19:44:43 -07:00
committed by GitHub
parent a20e7df8d0
commit 63e84352b7
7 changed files with 248 additions and 184 deletions

View File

@@ -1073,8 +1073,8 @@ fn generate_request_id() -> String {
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesRequest {
/// Run the request in the background
#[serde(default)]
pub background: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub background: Option<bool>,
/// Fields to include in the response
#[serde(skip_serializing_if = "Option::is_none")]
@@ -1108,8 +1108,8 @@ pub struct ResponsesRequest {
pub conversation: Option<String>,
/// Whether to enable parallel tool calls
#[serde(default = "default_true")]
pub parallel_tool_calls: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
/// ID of previous response to continue from
#[serde(skip_serializing_if = "Option::is_none")]
@@ -1120,40 +1120,40 @@ pub struct ResponsesRequest {
pub reasoning: Option<ResponseReasoningParam>,
/// Service tier
#[serde(default)]
pub service_tier: ServiceTier,
#[serde(skip_serializing_if = "Option::is_none")]
pub service_tier: Option<ServiceTier>,
/// Whether to store the response
#[serde(default = "default_true")]
pub store: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub store: Option<bool>,
/// Whether to stream the response
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
/// Temperature for sampling
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// Tool choice behavior
#[serde(default)]
pub tool_choice: ToolChoice,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Available tools
#[serde(default)]
pub tools: Vec<ResponseTool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ResponseTool>>,
/// Number of top logprobs to return
#[serde(default)]
pub top_logprobs: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u32>,
/// Top-p sampling parameter
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// Truncation behavior
#[serde(default)]
pub truncation: Truncation,
#[serde(skip_serializing_if = "Option::is_none")]
pub truncation: Option<Truncation>,
/// User identifier
#[serde(skip_serializing_if = "Option::is_none")]
@@ -1168,12 +1168,12 @@ pub struct ResponsesRequest {
pub priority: i32,
/// Frequency penalty
#[serde(default)]
pub frequency_penalty: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Presence penalty
#[serde(default)]
pub presence_penalty: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Stop sequences
#[serde(skip_serializing_if = "Option::is_none")]
@@ -1210,7 +1210,7 @@ fn default_repetition_penalty() -> f32 {
impl Default for ResponsesRequest {
fn default() -> Self {
Self {
background: false,
background: None,
include: None,
input: ResponseInput::Text(String::new()),
instructions: None,
@@ -1219,23 +1219,23 @@ impl Default for ResponsesRequest {
metadata: None,
model: None,
conversation: None,
parallel_tool_calls: true,
parallel_tool_calls: None,
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::default(),
store: true,
stream: false,
service_tier: None,
store: None,
stream: None,
temperature: None,
tool_choice: ToolChoice::default(),
tools: Vec::new(),
top_logprobs: 0,
tool_choice: None,
tools: None,
top_logprobs: None,
top_p: None,
truncation: Truncation::default(),
truncation: None,
user: None,
request_id: generate_request_id(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
frequency_penalty: None,
presence_penalty: None,
stop: None,
top_k: default_top_k(),
min_p: 0.0,
@@ -1299,14 +1299,18 @@ impl ResponsesRequest {
"top_p".to_string(),
Value::Number(Number::from_f64(top_p as f64).unwrap()),
);
params.insert(
"frequency_penalty".to_string(),
Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()),
);
params.insert(
"presence_penalty".to_string(),
Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()),
);
if let Some(fp) = self.frequency_penalty {
params.insert(
"frequency_penalty".to_string(),
Value::Number(Number::from_f64(fp as f64).unwrap()),
);
}
if let Some(pp) = self.presence_penalty {
params.insert(
"presence_penalty".to_string(),
Value::Number(Number::from_f64(pp as f64).unwrap()),
);
}
params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
params.insert(
"min_p".to_string(),
@@ -1337,7 +1341,7 @@ impl ResponsesRequest {
impl GenerationRequest for ResponsesRequest {
fn is_stream(&self) -> bool {
self.stream
self.stream.unwrap_or(false)
}
fn get_model(&self) -> Option<&str> {
@@ -1523,13 +1527,13 @@ impl ResponsesResponse {
max_output_tokens: request.max_output_tokens,
model: model_name,
output,
parallel_tool_calls: request.parallel_tool_calls,
parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
previous_response_id: request.previous_response_id.clone(),
reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo {
effort: r.effort.as_ref().map(|e| format!("{:?}", e)),
summary: None,
}),
store: request.store,
store: request.store.unwrap_or(false),
temperature: request.temperature,
text: Some(ResponseTextFormat {
format: TextFormatType {
@@ -1537,17 +1541,19 @@ impl ResponsesResponse {
},
}),
tool_choice: match &request.tool_choice {
ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
ToolChoice::Function { .. } => "function".to_string(),
ToolChoice::AllowedTools { mode, .. } => mode.clone(),
Some(ToolChoice::Value(ToolChoiceValue::Auto)) => "auto".to_string(),
Some(ToolChoice::Value(ToolChoiceValue::Required)) => "required".to_string(),
Some(ToolChoice::Value(ToolChoiceValue::None)) => "none".to_string(),
Some(ToolChoice::Function { .. }) => "function".to_string(),
Some(ToolChoice::AllowedTools { mode, .. }) => mode.clone(),
None => "auto".to_string(),
},
tools: request.tools.clone(),
tools: request.tools.clone().unwrap_or_default(),
top_p: request.top_p,
truncation: match &request.truncation {
Truncation::Auto => Some("auto".to_string()),
Truncation::Disabled => Some("disabled".to_string()),
Some(Truncation::Auto) => Some("auto".to_string()),
Some(Truncation::Disabled) => Some("disabled".to_string()),
None => None,
},
usage: usage.map(ResponsesUsage::Classic),
user: request.user.clone(),