[router] openai router: support grok model (#11511)
@@ -1073,8 +1073,8 @@ fn generate_request_id() -> String {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ResponsesRequest {
     /// Run the request in the background
-    #[serde(default)]
-    pub background: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub background: Option<bool>,
 
     /// Fields to include in the response
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1108,8 +1108,8 @@ pub struct ResponsesRequest {
     pub conversation: Option<String>,
 
     /// Whether to enable parallel tool calls
-    #[serde(default = "default_true")]
-    pub parallel_tool_calls: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
 
     /// ID of previous response to continue from
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1120,40 +1120,40 @@ pub struct ResponsesRequest {
     pub reasoning: Option<ResponseReasoningParam>,
 
     /// Service tier
-    #[serde(default)]
-    pub service_tier: ServiceTier,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option<ServiceTier>,
 
     /// Whether to store the response
-    #[serde(default = "default_true")]
-    pub store: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub store: Option<bool>,
 
     /// Whether to stream the response
-    #[serde(default)]
-    pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
 
     /// Temperature for sampling
     #[serde(skip_serializing_if = "Option::is_none")]
     pub temperature: Option<f32>,
 
     /// Tool choice behavior
-    #[serde(default)]
-    pub tool_choice: ToolChoice,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
 
     /// Available tools
-    #[serde(default)]
-    pub tools: Vec<ResponseTool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<ResponseTool>>,
 
     /// Number of top logprobs to return
-    #[serde(default)]
-    pub top_logprobs: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option<u32>,
 
     /// Top-p sampling parameter
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_p: Option<f32>,
 
     /// Truncation behavior
-    #[serde(default)]
-    pub truncation: Truncation,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub truncation: Option<Truncation>,
 
     /// User identifier
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1168,12 +1168,12 @@ pub struct ResponsesRequest {
     pub priority: i32,
 
     /// Frequency penalty
-    #[serde(default)]
-    pub frequency_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
 
     /// Presence penalty
-    #[serde(default)]
-    pub presence_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
 
     /// Stop sequences
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1210,7 +1210,7 @@ fn default_repetition_penalty() -> f32 {
 impl Default for ResponsesRequest {
     fn default() -> Self {
         Self {
-            background: false,
+            background: None,
             include: None,
             input: ResponseInput::Text(String::new()),
             instructions: None,
@@ -1219,23 +1219,23 @@ impl Default for ResponsesRequest {
             metadata: None,
             model: None,
             conversation: None,
-            parallel_tool_calls: true,
+            parallel_tool_calls: None,
             previous_response_id: None,
             reasoning: None,
-            service_tier: ServiceTier::default(),
-            store: true,
-            stream: false,
+            service_tier: None,
+            store: None,
+            stream: None,
             temperature: None,
-            tool_choice: ToolChoice::default(),
-            tools: Vec::new(),
-            top_logprobs: 0,
+            tool_choice: None,
+            tools: None,
+            top_logprobs: None,
             top_p: None,
-            truncation: Truncation::default(),
+            truncation: None,
             user: None,
             request_id: generate_request_id(),
             priority: 0,
-            frequency_penalty: 0.0,
-            presence_penalty: 0.0,
+            frequency_penalty: None,
+            presence_penalty: None,
             stop: None,
             top_k: default_top_k(),
             min_p: 0.0,
@@ -1299,14 +1299,18 @@ impl ResponsesRequest {
             "top_p".to_string(),
             Value::Number(Number::from_f64(top_p as f64).unwrap()),
         );
-        params.insert(
-            "frequency_penalty".to_string(),
-            Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()),
-        );
-        params.insert(
-            "presence_penalty".to_string(),
-            Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()),
-        );
+        if let Some(fp) = self.frequency_penalty {
+            params.insert(
+                "frequency_penalty".to_string(),
+                Value::Number(Number::from_f64(fp as f64).unwrap()),
+            );
+        }
+        if let Some(pp) = self.presence_penalty {
+            params.insert(
+                "presence_penalty".to_string(),
+                Value::Number(Number::from_f64(pp as f64).unwrap()),
+            );
+        }
         params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
         params.insert(
             "min_p".to_string(),
@@ -1337,7 +1341,7 @@ impl ResponsesRequest {
 
 impl GenerationRequest for ResponsesRequest {
     fn is_stream(&self) -> bool {
-        self.stream
+        self.stream.unwrap_or(false)
     }
 
     fn get_model(&self) -> Option<&str> {
@@ -1523,13 +1527,13 @@ impl ResponsesResponse {
             max_output_tokens: request.max_output_tokens,
             model: model_name,
             output,
-            parallel_tool_calls: request.parallel_tool_calls,
+            parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
             previous_response_id: request.previous_response_id.clone(),
             reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo {
                 effort: r.effort.as_ref().map(|e| format!("{:?}", e)),
                 summary: None,
             }),
-            store: request.store,
+            store: request.store.unwrap_or(false),
             temperature: request.temperature,
             text: Some(ResponseTextFormat {
                 format: TextFormatType {
@@ -1537,17 +1541,19 @@ impl ResponsesResponse {
                 },
             }),
             tool_choice: match &request.tool_choice {
-                ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
-                ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
-                ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
-                ToolChoice::Function { .. } => "function".to_string(),
-                ToolChoice::AllowedTools { mode, .. } => mode.clone(),
+                Some(ToolChoice::Value(ToolChoiceValue::Auto)) => "auto".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::Required)) => "required".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::None)) => "none".to_string(),
+                Some(ToolChoice::Function { .. }) => "function".to_string(),
+                Some(ToolChoice::AllowedTools { mode, .. }) => mode.clone(),
+                None => "auto".to_string(),
             },
-            tools: request.tools.clone(),
+            tools: request.tools.clone().unwrap_or_default(),
             top_p: request.top_p,
             truncation: match &request.truncation {
-                Truncation::Auto => Some("auto".to_string()),
-                Truncation::Disabled => Some("disabled".to_string()),
+                Some(Truncation::Auto) => Some("auto".to_string()),
+                Some(Truncation::Disabled) => Some("disabled".to_string()),
+                None => None,
             },
             usage: usage.map(ResponsesUsage::Classic),
             user: request.user.clone(),
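
For reference, a minimal sketch (not part of this commit) of the serde semantics the Option-wrapped fields rely on: a field absent from the incoming JSON deserializes to None, and a None field is omitted again on serialization because of skip_serializing_if. The Demo struct and main function below are hypothetical, used only to illustrate that behavior.

// Hypothetical illustration, not code from this diff.
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct Demo {
    // Absent in the input -> None; None -> omitted from the output.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

fn main() {
    // Field missing in the request body: serde fills in None.
    let d: Demo = serde_json::from_str("{}").unwrap();
    assert_eq!(d.stream, None);
    // None is skipped on serialization, so no explicit `"stream": null`
    // is ever forwarded upstream.
    assert_eq!(serde_json::to_string(&d).unwrap(), "{}");
    // An explicit value round-trips unchanged.
    let d = Demo { stream: Some(true) };
    assert_eq!(serde_json::to_string(&d).unwrap(), r#"{"stream":true}"#);
}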