diff --git a/sgl-router/src/protocols/openai/mod.rs b/sgl-router/src/protocols/openai/mod.rs
index 83c7ddfba..08495b92b 100644
--- a/sgl-router/src/protocols/openai/mod.rs
+++ b/sgl-router/src/protocols/openai/mod.rs
@@ -5,3 +5,4 @@ pub mod chat;
 pub mod common;
 pub mod completions;
 pub mod errors;
+pub mod responses;
diff --git a/sgl-router/src/protocols/openai/responses/mod.rs b/sgl-router/src/protocols/openai/responses/mod.rs
new file mode 100644
index 000000000..e513116fd
--- /dev/null
+++ b/sgl-router/src/protocols/openai/responses/mod.rs
@@ -0,0 +1,10 @@
+// Responses API module
+
+pub mod request;
+pub mod response;
+pub mod types;
+
+// Re-export main types for convenience
+pub use request::ResponsesRequest;
+pub use response::ResponsesResponse;
+pub use types::*;
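For orientation (not part of the diff): with the re-exports above, a caller only needs to supply the `input` field; the serde defaults declared in `request.rs` below fill in everything else. A minimal sketch, assuming the `sglang_router_rs` crate name used by the tests at the end of this PR and `serde_json` as a dependency:

```rust
use sglang_router_rs::protocols::openai::responses::ResponsesRequest;

fn main() {
    // `input` is the only required field; the rest comes from serde defaults.
    let req: ResponsesRequest = serde_json::from_str(r#"{"input": "Hello"}"#)
        .expect("minimal body should deserialize");
    assert!(req.store); // default_true
    assert_eq!(req.top_k, -1); // default_top_k
    assert!(req.request_id.starts_with("resp_")); // generate_request_id
}
```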
diff --git a/sgl-router/src/protocols/openai/responses/request.rs b/sgl-router/src/protocols/openai/responses/request.rs
new file mode 100644
index 000000000..575b487de
--- /dev/null
+++ b/sgl-router/src/protocols/openai/responses/request.rs
@@ -0,0 +1,300 @@
+// Responses API request types
+
+use crate::protocols::common::{GenerationRequest, StringOrArray};
+use crate::protocols::openai::responses::types::*;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+fn generate_request_id() -> String {
+    format!("resp_{}", uuid::Uuid::new_v4().simple())
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ResponsesRequest {
+    // ============= Core OpenAI API fields =============
+    /// Run the request in the background
+    #[serde(default)]
+    pub background: bool,
+
+    /// Fields to include in the response
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub include: Option<Vec<IncludeField>>,
+
+    /// Input content - can be a string or structured items
+    pub input: ResponseInput,
+
+    /// System instructions for the model
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub instructions: Option<String>,
+
+    /// Maximum number of output tokens
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_output_tokens: Option<u32>,
+
+    /// Maximum number of tool calls
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_tool_calls: Option<u32>,
+
+    /// Additional metadata
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<HashMap<String, serde_json::Value>>,
+
+    /// Model to use (optional to match vLLM)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+
+    /// Whether to enable parallel tool calls
+    #[serde(default = "default_true")]
+    pub parallel_tool_calls: bool,
+
+    /// ID of a previous response to continue from
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub previous_response_id: Option<String>,
+
+    /// Reasoning configuration
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning: Option<ResponseReasoningParam>,
+
+    /// Service tier
+    #[serde(default)]
+    pub service_tier: ServiceTier,
+
+    /// Whether to store the response
+    #[serde(default = "default_true")]
+    pub store: bool,
+
+    /// Whether to stream the response
+    #[serde(default)]
+    pub stream: bool,
+
+    /// Temperature for sampling
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+
+    /// Tool choice behavior
+    #[serde(default)]
+    pub tool_choice: ToolChoice,
+
+    /// Available tools
+    #[serde(default)]
+    pub tools: Vec<ResponseTool>,
+
+    /// Number of top logprobs to return
+    #[serde(default)]
+    pub top_logprobs: u32,
+
+    /// Top-p sampling parameter
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+
+    /// Truncation behavior
+    #[serde(default)]
+    pub truncation: Truncation,
+
+    /// User identifier
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub user: Option<String>,
+
+    // ============= SGLang Extensions =============
+    /// Request ID
+    #[serde(default = "generate_request_id")]
+    pub request_id: String,
+
+    /// Request priority
+    #[serde(default)]
+    pub priority: i32,
+
+    /// Frequency penalty
+    #[serde(default)]
+    pub frequency_penalty: f32,
+
+    /// Presence penalty
+    #[serde(default)]
+    pub presence_penalty: f32,
+
+    /// Stop sequences
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stop: Option<StringOrArray>,
+
+    /// Top-k sampling parameter
+    #[serde(default = "default_top_k")]
+    pub top_k: i32,
+
+    /// Min-p sampling parameter
+    #[serde(default)]
+    pub min_p: f32,
+
+    /// Repetition penalty
+    #[serde(default = "default_repetition_penalty")]
+    pub repetition_penalty: f32,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+#[serde(untagged)]
+pub enum ResponseInput {
+    Text(String),
+    Items(Vec<ResponseInputOutputItem>),
+}
+
+fn default_top_k() -> i32 {
+    -1
+}
+
+fn default_repetition_penalty() -> f32 {
+    1.0
+}
+
+fn default_true() -> bool {
+    true
+}
+
+impl ResponsesRequest {
+    /// Default sampling parameters
+    const DEFAULT_TEMPERATURE: f32 = 0.7;
+    const DEFAULT_TOP_P: f32 = 1.0;
+
+    /// Convert to sampling parameters for generation
+    pub fn to_sampling_params(
+        &self,
+        default_max_tokens: u32,
+        default_params: Option<HashMap<String, serde_json::Value>>,
+    ) -> HashMap<String, serde_json::Value> {
+        let mut params = HashMap::new();
+
+        // Use max_output_tokens if available, capped by the server-side default
+        let max_tokens = if let Some(max_output) = self.max_output_tokens {
+            std::cmp::min(max_output, default_max_tokens)
+        } else {
+            default_max_tokens
+        };
+
+        // Reserve one token to avoid exceeding the context length
+        let max_tokens = max_tokens.saturating_sub(1);
+
+        // Temperature
+        let temperature = self.temperature.unwrap_or_else(|| {
+            default_params
+                .as_ref()
+                .and_then(|p| p.get("temperature"))
+                .and_then(|v| v.as_f64())
+                .map(|v| v as f32)
+                .unwrap_or(Self::DEFAULT_TEMPERATURE)
+        });
+
+        // Top-p
+        let top_p = self.top_p.unwrap_or_else(|| {
+            default_params
+                .as_ref()
+                .and_then(|p| p.get("top_p"))
+                .and_then(|v| v.as_f64())
+                .map(|v| v as f32)
+                .unwrap_or(Self::DEFAULT_TOP_P)
+        });
+
+        params.insert(
+            "max_new_tokens".to_string(),
+            serde_json::Value::Number(serde_json::Number::from(max_tokens)),
+        );
+        params.insert(
+            "temperature".to_string(),
+            serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()),
+        );
+        params.insert(
+            "top_p".to_string(),
+            serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()),
+        );
+        params.insert(
+            "frequency_penalty".to_string(),
+            serde_json::Value::Number(
+                serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(),
+            ),
+        );
+        params.insert(
+            "presence_penalty".to_string(),
+            serde_json::Value::Number(
+                serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(),
+            ),
+        );
+        params.insert(
+            "top_k".to_string(),
+            serde_json::Value::Number(serde_json::Number::from(self.top_k)),
+        );
+        params.insert(
+            "min_p".to_string(),
+            serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()),
+        );
+        params.insert(
+            "repetition_penalty".to_string(),
+            serde_json::Value::Number(
+                serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(),
+            ),
+        );
+
+        if let Some(ref stop) = self.stop {
+            match serde_json::to_value(stop) {
+                Ok(value) => params.insert("stop".to_string(), value),
+                Err(_) => params.insert("stop".to_string(), serde_json::Value::Null),
+            };
+        }
+
+        // Apply any remaining default parameters without overriding explicit settings
+        if let Some(default_params) = default_params {
+            for (key, value) in default_params {
+                params.entry(key).or_insert(value);
+            }
+        }
+
+        params
+    }
+}
+
+impl GenerationRequest for ResponsesRequest {
+    fn is_stream(&self) -> bool {
+        self.stream
+    }
+
+    fn get_model(&self) -> Option<&str> {
+        self.model.as_deref()
+    }
+
+    fn extract_text_for_routing(&self) -> String {
+        match &self.input {
+            ResponseInput::Text(text) => text.clone(),
+            ResponseInput::Items(items) => items
+                .iter()
+                .filter_map(|item| match item {
+                    ResponseInputOutputItem::Message { content, .. } => {
+                        let texts: Vec<String> = content
+                            .iter()
+                            .map(|part| match part {
+                                ResponseContentPart::OutputText { text, .. } => text.clone(),
+                            })
+                            .collect();
+                        if texts.is_empty() {
+                            None
+                        } else {
+                            Some(texts.join(" "))
+                        }
+                    }
+                    ResponseInputOutputItem::Reasoning { content, .. } => {
+                        let texts: Vec<String> = content
+                            .iter()
+                            .map(|part| match part {
+                                ResponseReasoningContent::ReasoningText { text } => text.clone(),
+                            })
+                            .collect();
+                        if texts.is_empty() {
+                            None
+                        } else {
+                            Some(texts.join(" "))
+                        }
+                    }
+                    ResponseInputOutputItem::FunctionToolCall { arguments, .. } => {
+                        Some(arguments.clone())
+                    }
+                })
+                .collect::<Vec<String>>()
+                .join(" "),
+        }
+    }
+}
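A short usage sketch for the two conversions this file provides (again not part of the diff; `serde_json` is assumed as a dependency):

```rust
use sglang_router_rs::protocols::common::GenerationRequest;
use sglang_router_rs::protocols::openai::responses::ResponsesRequest;

fn main() {
    let mut req: ResponsesRequest =
        serde_json::from_str(r#"{"input": "ping", "max_output_tokens": 64}"#).unwrap();
    req.temperature = Some(0.2);

    // One token is reserved, so the cap becomes min(64, 4096) - 1 = 63.
    let params = req.to_sampling_params(4096, None);
    assert_eq!(params["max_new_tokens"], serde_json::json!(63));

    // Routing text for a plain-string input is the string itself.
    assert_eq!(req.extract_text_for_routing(), "ping");
}
```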
diff --git a/sgl-router/src/protocols/openai/responses/response.rs b/sgl-router/src/protocols/openai/responses/response.rs
new file mode 100644
index 000000000..b124ce7d4
--- /dev/null
+++ b/sgl-router/src/protocols/openai/responses/response.rs
@@ -0,0 +1,280 @@
+// Responses API response types
+
+use crate::protocols::openai::common::ChatLogProbs;
+use crate::protocols::openai::responses::request::ResponsesRequest;
+use crate::protocols::openai::responses::types::*;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+fn generate_response_id() -> String {
+    format!("resp_{}", uuid::Uuid::new_v4().simple())
+}
+
+fn current_timestamp() -> i64 {
+    std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .unwrap_or_else(|_| std::time::Duration::from_secs(0))
+        .as_secs() as i64
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ResponsesResponse {
+    /// Response ID
+    #[serde(default = "generate_response_id")]
+    pub id: String,
+
+    /// Object type
+    #[serde(default = "default_object_type")]
+    pub object: String,
+
+    /// Creation timestamp
+    #[serde(default = "current_timestamp")]
+    pub created_at: i64,
+
+    /// Model name
+    pub model: String,
+
+    /// Output items
+    #[serde(default)]
+    pub output: Vec<ResponseOutputItem>,
+
+    /// Response status
+    pub status: ResponseStatus,
+
+    /// Usage information
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub usage: Option<UsageInfo>,
+
+    /// Whether parallel tool calls are enabled
+    #[serde(default = "default_true")]
+    pub parallel_tool_calls: bool,
+
+    /// Tool choice setting
+    #[serde(default = "default_tool_choice")]
+    pub tool_choice: String,
+
+    /// Available tools
+    #[serde(default)]
+    pub tools: Vec<ResponseTool>,
+}
+
+fn default_object_type() -> String {
+    "response".to_string()
+}
+
+fn default_true() -> bool {
+    true
+}
+
+fn default_tool_choice() -> String {
+    "auto".to_string()
+}
+
+impl ResponsesResponse {
+    /// Create a response from a request
+    #[allow(clippy::too_many_arguments)]
+    pub fn from_request(
+        request: &ResponsesRequest,
+        _sampling_params: &HashMap<String, serde_json::Value>,
+        model_name: String,
+        created_time: i64,
+        output: Vec<ResponseOutputItem>,
+        status: ResponseStatus,
+        usage: Option<UsageInfo>,
+    ) -> Self {
+        Self {
+            id: request.request_id.clone(),
+            object: "response".to_string(),
+            created_at: created_time,
+            model: model_name,
+            output,
+            status,
+            usage,
+            parallel_tool_calls: request.parallel_tool_calls,
+            tool_choice: match request.tool_choice {
+                ToolChoice::Auto => "auto".to_string(),
+                ToolChoice::Required => "required".to_string(),
+                ToolChoice::None => "none".to_string(),
+            },
+            tools: request.tools.clone(),
+        }
+    }
+
+    /// Create a new response with default values
+    pub fn new(request_id: String, model: String, status: ResponseStatus) -> Self {
+        Self {
+            id: request_id,
+            object: "response".to_string(),
+            created_at: current_timestamp(),
+            model,
+            output: Vec::new(),
+            status,
+            usage: None,
+            parallel_tool_calls: true,
+            tool_choice: "auto".to_string(),
+            tools: Vec::new(),
+        }
+    }
+
+    /// Add an output item to the response
+    pub fn add_output(&mut self, item: ResponseOutputItem) {
+        self.output.push(item);
+    }
+
+    /// Set the usage information
+    pub fn set_usage(&mut self, usage: UsageInfo) {
+        self.usage = Some(usage);
+    }
+
+    /// Update the status
+    pub fn set_status(&mut self, status: ResponseStatus) {
+        self.status = status;
+    }
+
+    /// Check if the response is complete
+    pub fn is_complete(&self) -> bool {
+        matches!(self.status, ResponseStatus::Completed)
+    }
+
+    /// Check if the response is in progress
+    pub fn is_in_progress(&self) -> bool {
+        matches!(self.status, ResponseStatus::InProgress)
+    }
+
+    /// Check if the response failed
+    pub fn is_failed(&self) -> bool {
+        matches!(self.status, ResponseStatus::Failed)
+    }
+
+    /// Check if the response was cancelled
+    pub fn is_cancelled(&self) -> bool {
+        matches!(self.status, ResponseStatus::Cancelled)
+    }
+
+    /// Check if the response is queued
+    pub fn is_queued(&self) -> bool {
+        matches!(self.status, ResponseStatus::Queued)
+    }
+
+    /// Convert usage to the OpenAI Responses API format
+    pub fn usage_in_response_format(&self) -> Option<ResponseUsage> {
+        self.usage.as_ref().map(|usage| usage.to_response_usage())
+    }
+
+    /// Get the response as a JSON value with usage in response format
+    pub fn to_response_format(&self) -> serde_json::Value {
+        let mut response = serde_json::to_value(self).unwrap_or(serde_json::Value::Null);
+
+        // Convert usage to response format if present
+        if let Some(usage) = &self.usage {
+            if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) {
+                response["usage"] = usage_value;
+            }
+        }
+
+        response
+    }
+}
+
+// ============= Helper Functions =============
+
+impl ResponseOutputItem {
+    /// Create a new message output item
+    pub fn new_message(
+        id: String,
+        role: String,
+        content: Vec<ResponseContentPart>,
+        status: String,
+    ) -> Self {
+        Self::Message {
+            id,
+            role,
+            content,
+            status,
+        }
+    }
+
+    /// Create a new reasoning output item
+    pub fn new_reasoning(
+        id: String,
+        summary: Vec<String>,
+        content: Vec<ResponseReasoningContent>,
+        status: Option<String>,
+    ) -> Self {
+        Self::Reasoning {
+            id,
+            summary,
+            content,
+            status,
+        }
+    }
+
+    /// Create a new function tool call output item
+    pub fn new_function_tool_call(
+        id: String,
+        name: String,
+        arguments: String,
+        output: Option<String>,
+        status: String,
+    ) -> Self {
+        Self::FunctionToolCall {
+            id,
+            name,
+            arguments,
+            output,
+            status,
+        }
+    }
+}
+
+impl ResponseContentPart {
+    /// Create a new text content part
+    pub fn new_text(
+        text: String,
+        annotations: Vec<serde_json::Value>,
+        logprobs: Option<ChatLogProbs>,
+    ) -> Self {
+        Self::OutputText {
+            text,
+            annotations,
+            logprobs,
+        }
+    }
+}
+
+impl ResponseReasoningContent {
+    /// Create a new reasoning text content
+    pub fn new_reasoning_text(text: String) -> Self {
+        Self::ReasoningText { text }
+    }
+}
+
+impl UsageInfo {
+    /// Create a new usage info with token counts
+    pub fn new(prompt_tokens: u32, completion_tokens: u32, reasoning_tokens: Option<u32>) -> Self {
+        Self {
+            prompt_tokens,
+            completion_tokens,
+            total_tokens: prompt_tokens + completion_tokens,
+            reasoning_tokens,
+            prompt_tokens_details: None,
+        }
+    }
+
+    /// Create usage info with cached token details
+    pub fn new_with_cached(
+        prompt_tokens: u32,
+        completion_tokens: u32,
+        reasoning_tokens: Option<u32>,
+        cached_tokens: u32,
+    ) -> Self {
+        Self {
+            prompt_tokens,
+            completion_tokens,
+            total_tokens: prompt_tokens + completion_tokens,
+            reasoning_tokens,
+            prompt_tokens_details: Some(PromptTokenUsageInfo { cached_tokens }),
+        }
+    }
+}
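The helpers above allow a response to be assembled incrementally. A hedged sketch of the intended flow, using only constructors defined in this file (names like `resp_demo` are placeholders):

```rust
use sglang_router_rs::protocols::openai::responses::*;

fn main() {
    let mut resp = ResponsesResponse::new(
        "resp_demo".to_string(),
        "demo-model".to_string(),
        ResponseStatus::InProgress,
    );
    resp.add_output(ResponseOutputItem::new_message(
        "msg_0".to_string(),
        "assistant".to_string(),
        vec![ResponseContentPart::new_text("Hi!".to_string(), vec![], None)],
        "completed".to_string(),
    ));
    resp.set_usage(UsageInfo::new(10, 5, None));
    resp.set_status(ResponseStatus::Completed);
    assert!(resp.is_complete());

    // to_response_format swaps prompt/completion counts for the
    // input_tokens/output_tokens naming the Responses API uses.
    let json = resp.to_response_format();
    assert_eq!(json["usage"]["input_tokens"], serde_json::json!(10));
}
```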
"Vec::is_empty")] + summary: Vec, + content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + }, + #[serde(rename = "function_tool_call")] + FunctionToolCall { + id: String, + name: String, + arguments: String, + #[serde(skip_serializing_if = "Option::is_none")] + output: Option, + status: String, + }, +} + +// ============= Service Tier ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ServiceTier { + Auto, + Default, + Flex, + Scale, + Priority, +} + +impl Default for ServiceTier { + fn default() -> Self { + Self::Auto + } +} + +// ============= Tool Choice ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoice { + Auto, + Required, + None, +} + +impl Default for ToolChoice { + fn default() -> Self { + Self::Auto + } +} + +// ============= Truncation ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum Truncation { + Auto, + Disabled, +} + +impl Default for Truncation { + fn default() -> Self { + Self::Disabled + } +} + +// ============= Response Status ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum ResponseStatus { + Queued, + InProgress, + Completed, + Failed, + Cancelled, +} + +// ============= Include Fields ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(rename_all = "snake_case")] +pub enum IncludeField { + #[serde(rename = "code_interpreter_call.outputs")] + CodeInterpreterCallOutputs, + #[serde(rename = "computer_call_output.output.image_url")] + ComputerCallOutputImageUrl, + #[serde(rename = "file_search_call.results")] + FileSearchCallResults, + #[serde(rename = "message.input_image.image_url")] + MessageInputImageUrl, + #[serde(rename = "message.output_text.logprobs")] + MessageOutputTextLogprobs, + #[serde(rename = "reasoning.encrypted_content")] + ReasoningEncryptedContent, +} + +// ============= Usage Info ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct UsageInfo { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct PromptTokenUsageInfo { + pub cached_tokens: u32, +} + +// ============= Response Usage Format ============= + +/// OpenAI Responses API usage format (different from standard UsageInfo) +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResponseUsage { + pub input_tokens: u32, + pub output_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub input_tokens_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct InputTokensDetails { + pub cached_tokens: u32, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct OutputTokensDetails { + pub reasoning_tokens: u32, +} + +impl UsageInfo { + /// Convert to OpenAI Responses API format + pub fn to_response_usage(&self) -> ResponseUsage { + ResponseUsage { + input_tokens: self.prompt_tokens, + output_tokens: self.completion_tokens, + total_tokens: self.total_tokens, + input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| { + 
+
+// ============= Usage Info =============
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct UsageInfo {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_tokens: Option<u32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct PromptTokenUsageInfo {
+    pub cached_tokens: u32,
+}
+
+// ============= Response Usage Format =============
+
+/// OpenAI Responses API usage format (different from standard UsageInfo)
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ResponseUsage {
+    pub input_tokens: u32,
+    pub output_tokens: u32,
+    pub total_tokens: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input_tokens_details: Option<InputTokensDetails>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output_tokens_details: Option<OutputTokensDetails>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct InputTokensDetails {
+    pub cached_tokens: u32,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct OutputTokensDetails {
+    pub reasoning_tokens: u32,
+}
+
+impl UsageInfo {
+    /// Convert to the OpenAI Responses API format
+    pub fn to_response_usage(&self) -> ResponseUsage {
+        ResponseUsage {
+            input_tokens: self.prompt_tokens,
+            output_tokens: self.completion_tokens,
+            total_tokens: self.total_tokens,
+            input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| {
+                InputTokensDetails {
+                    cached_tokens: details.cached_tokens,
+                }
+            }),
+            output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails {
+                reasoning_tokens: tokens,
+            }),
+        }
+    }
+}
+
+impl From<UsageInfo> for ResponseUsage {
+    fn from(usage: UsageInfo) -> Self {
+        usage.to_response_usage()
+    }
+}
+
+impl ResponseUsage {
+    /// Convert back to the standard UsageInfo format
+    pub fn to_usage_info(&self) -> UsageInfo {
+        UsageInfo {
+            prompt_tokens: self.input_tokens,
+            completion_tokens: self.output_tokens,
+            total_tokens: self.total_tokens,
+            reasoning_tokens: self
+                .output_tokens_details
+                .as_ref()
+                .map(|details| details.reasoning_tokens),
+            prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| {
+                PromptTokenUsageInfo {
+                    cached_tokens: details.cached_tokens,
+                }
+            }),
+        }
+    }
+}
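One note on the serde layout before the tests: the enums in `types.rs` are internally tagged, so the `"type"` discriminator sits inline with the payload, and `IncludeField` relies on explicit renames because its wire values contain dots, which `rename_all = "snake_case"` cannot produce. A quick illustration (not part of the diff):

```rust
use sglang_router_rs::protocols::openai::responses::*;

fn main() {
    // Internally tagged enums put a "type" discriminator inline; empty
    // annotations and absent logprobs are skipped on serialization.
    let part = ResponseContentPart::OutputText {
        text: "Hi".to_string(),
        annotations: vec![],
        logprobs: None,
    };
    assert_eq!(
        serde_json::to_value(&part).unwrap(),
        serde_json::json!({"type": "output_text", "text": "Hi"})
    );

    // Explicit renames map IncludeField to the dotted paths the API uses.
    let field = IncludeField::ReasoningEncryptedContent;
    assert_eq!(
        serde_json::to_string(&field).unwrap(),
        r#""reasoning.encrypted_content""#
    );
}
```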
assert!(params.contains_key("top_p")); + assert!(params.contains_key("frequency_penalty")); + assert!(params.contains_key("max_new_tokens")); +} + +#[test] +fn test_responses_response_creation() { + let response = ResponsesResponse::new( + "resp_test789".to_string(), + "test-model".to_string(), + ResponseStatus::Completed, + ); + + assert_eq!(response.id, "resp_test789"); + assert_eq!(response.model, "test-model"); + assert!(response.is_complete()); + assert!(!response.is_in_progress()); + assert!(!response.is_failed()); +} + +#[test] +fn test_usage_conversion() { + let usage_info = UsageInfo::new_with_cached(15, 25, Some(8), 3); + let response_usage = usage_info.to_response_usage(); + + assert_eq!(response_usage.input_tokens, 15); + assert_eq!(response_usage.output_tokens, 25); + assert_eq!(response_usage.total_tokens, 40); + + // Check details are converted correctly + assert!(response_usage.input_tokens_details.is_some()); + assert_eq!( + response_usage + .input_tokens_details + .as_ref() + .unwrap() + .cached_tokens, + 3 + ); + + assert!(response_usage.output_tokens_details.is_some()); + assert_eq!( + response_usage + .output_tokens_details + .as_ref() + .unwrap() + .reasoning_tokens, + 8 + ); + + // Test reverse conversion + let back_to_usage = response_usage.to_usage_info(); + assert_eq!(back_to_usage.prompt_tokens, 15); + assert_eq!(back_to_usage.completion_tokens, 25); + assert_eq!(back_to_usage.reasoning_tokens, Some(8)); +} + +#[test] +fn test_reasoning_param_default() { + let param = ResponseReasoningParam { + effort: Some(ReasoningEffort::Medium), + }; + + // Test JSON serialization/deserialization preserves default + let json = serde_json::to_string(¶m).unwrap(); + let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap(); + + assert!(matches!(parsed.effort, Some(ReasoningEffort::Medium))); +} + +#[test] +fn test_json_serialization() { + let request = ResponsesRequest { + background: true, + include: None, + input: ResponseInput::Text("Test input".to_string()), + instructions: Some("Test instructions".to_string()), + max_output_tokens: Some(200), + max_tool_calls: Some(5), + metadata: None, + model: Some("gpt-4".to_string()), + parallel_tool_calls: false, + previous_response_id: None, + reasoning: Some(ResponseReasoningParam { + effort: Some(ReasoningEffort::High), + }), + service_tier: ServiceTier::Priority, + store: false, + stream: true, + temperature: Some(0.9), + tool_choice: ToolChoice::Required, + tools: vec![ResponseTool { + r#type: ResponseToolType::CodeInterpreter, + }], + top_logprobs: 10, + top_p: Some(0.8), + truncation: Truncation::Auto, + user: Some("test_user".to_string()), + request_id: "resp_comprehensive_test".to_string(), + priority: 1, + frequency_penalty: 0.3, + presence_penalty: 0.4, + stop: None, + top_k: 50, + min_p: 0.1, + repetition_penalty: 1.2, + }; + + // Test that everything can be serialized to JSON and back + let json = serde_json::to_string(&request).expect("Serialization should work"); + let parsed: ResponsesRequest = + serde_json::from_str(&json).expect("Deserialization should work"); + + assert_eq!(parsed.request_id, "resp_comprehensive_test"); + assert_eq!(parsed.model, Some("gpt-4".to_string())); + assert!(parsed.background); + assert!(parsed.stream); + assert_eq!(parsed.tools.len(), 1); +}