[router] Implement OpenAI Responses API specification (#9367)
This commit is contained in:
@@ -5,3 +5,4 @@ pub mod chat;
|
||||
pub mod common;
|
||||
pub mod completions;
|
||||
pub mod errors;
|
||||
pub mod responses;
|
||||
|
||||
10
sgl-router/src/protocols/openai/responses/mod.rs
Normal file
10
sgl-router/src/protocols/openai/responses/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
// Responses API module
|
||||
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod types;
|
||||
|
||||
// Re-export main types for convenience
|
||||
pub use request::ResponsesRequest;
|
||||
pub use response::ResponsesResponse;
|
||||
pub use types::*;
|
||||
300
sgl-router/src/protocols/openai/responses/request.rs
Normal file
300
sgl-router/src/protocols/openai/responses/request.rs
Normal file
@@ -0,0 +1,300 @@
|
||||
// Responses API request types
|
||||
|
||||
use crate::protocols::common::{GenerationRequest, StringOrArray};
|
||||
use crate::protocols::openai::responses::types::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn generate_request_id() -> String {
|
||||
format!("resp_{}", uuid::Uuid::new_v4().simple())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponsesRequest {
|
||||
// ============= Core OpenAI API fields =============
|
||||
/// Run the request in the background
|
||||
#[serde(default)]
|
||||
pub background: bool,
|
||||
|
||||
/// Fields to include in the response
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub include: Option<Vec<IncludeField>>,
|
||||
|
||||
/// Input content - can be string or structured items
|
||||
pub input: ResponseInput,
|
||||
|
||||
/// System instructions for the model
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub instructions: Option<String>,
|
||||
|
||||
/// Maximum number of output tokens
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_output_tokens: Option<u32>,
|
||||
|
||||
/// Maximum number of tool calls
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tool_calls: Option<u32>,
|
||||
|
||||
/// Additional metadata
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, serde_json::Value>>,
|
||||
|
||||
/// Model to use (optional to match vLLM)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Whether to enable parallel tool calls
|
||||
#[serde(default = "default_true")]
|
||||
pub parallel_tool_calls: bool,
|
||||
|
||||
/// ID of previous response to continue from
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub previous_response_id: Option<String>,
|
||||
|
||||
/// Reasoning configuration
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning: Option<ResponseReasoningParam>,
|
||||
|
||||
/// Service tier
|
||||
#[serde(default)]
|
||||
pub service_tier: ServiceTier,
|
||||
|
||||
/// Whether to store the response
|
||||
#[serde(default = "default_true")]
|
||||
pub store: bool,
|
||||
|
||||
/// Whether to stream the response
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
|
||||
/// Temperature for sampling
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// Tool choice behavior
|
||||
#[serde(default)]
|
||||
pub tool_choice: ToolChoice,
|
||||
|
||||
/// Available tools
|
||||
#[serde(default)]
|
||||
pub tools: Vec<ResponseTool>,
|
||||
|
||||
/// Number of top logprobs to return
|
||||
#[serde(default)]
|
||||
pub top_logprobs: u32,
|
||||
|
||||
/// Top-p sampling parameter
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// Truncation behavior
|
||||
#[serde(default)]
|
||||
pub truncation: Truncation,
|
||||
|
||||
/// User identifier
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
// ============= SGLang Extensions =============
|
||||
/// Request ID
|
||||
#[serde(default = "generate_request_id")]
|
||||
pub request_id: String,
|
||||
|
||||
/// Request priority
|
||||
#[serde(default)]
|
||||
pub priority: i32,
|
||||
|
||||
/// Frequency penalty
|
||||
#[serde(default)]
|
||||
pub frequency_penalty: f32,
|
||||
|
||||
/// Presence penalty
|
||||
#[serde(default)]
|
||||
pub presence_penalty: f32,
|
||||
|
||||
/// Stop sequences
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<StringOrArray>,
|
||||
|
||||
/// Top-k sampling parameter
|
||||
#[serde(default = "default_top_k")]
|
||||
pub top_k: i32,
|
||||
|
||||
/// Min-p sampling parameter
|
||||
#[serde(default)]
|
||||
pub min_p: f32,
|
||||
|
||||
/// Repetition penalty
|
||||
#[serde(default = "default_repetition_penalty")]
|
||||
pub repetition_penalty: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ResponseInput {
|
||||
Text(String),
|
||||
Items(Vec<ResponseInputOutputItem>),
|
||||
}
|
||||
|
||||
fn default_top_k() -> i32 {
|
||||
-1
|
||||
}
|
||||
|
||||
fn default_repetition_penalty() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl ResponsesRequest {
|
||||
/// Default sampling parameters
|
||||
const DEFAULT_TEMPERATURE: f32 = 0.7;
|
||||
const DEFAULT_TOP_P: f32 = 1.0;
|
||||
|
||||
/// Convert to sampling parameters for generation
|
||||
pub fn to_sampling_params(
|
||||
&self,
|
||||
default_max_tokens: u32,
|
||||
default_params: Option<HashMap<String, serde_json::Value>>,
|
||||
) -> HashMap<String, serde_json::Value> {
|
||||
let mut params = HashMap::new();
|
||||
|
||||
// Use max_output_tokens if available
|
||||
let max_tokens = if let Some(max_output) = self.max_output_tokens {
|
||||
std::cmp::min(max_output, default_max_tokens)
|
||||
} else {
|
||||
default_max_tokens
|
||||
};
|
||||
|
||||
// Avoid exceeding context length by minus 1 token
|
||||
let max_tokens = max_tokens.saturating_sub(1);
|
||||
|
||||
// Temperature
|
||||
let temperature = self.temperature.unwrap_or_else(|| {
|
||||
default_params
|
||||
.as_ref()
|
||||
.and_then(|p| p.get("temperature"))
|
||||
.and_then(|v| v.as_f64())
|
||||
.map(|v| v as f32)
|
||||
.unwrap_or(Self::DEFAULT_TEMPERATURE)
|
||||
});
|
||||
|
||||
// Top-p
|
||||
let top_p = self.top_p.unwrap_or_else(|| {
|
||||
default_params
|
||||
.as_ref()
|
||||
.and_then(|p| p.get("top_p"))
|
||||
.and_then(|v| v.as_f64())
|
||||
.map(|v| v as f32)
|
||||
.unwrap_or(Self::DEFAULT_TOP_P)
|
||||
});
|
||||
|
||||
params.insert(
|
||||
"max_new_tokens".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from(max_tokens)),
|
||||
);
|
||||
params.insert(
|
||||
"temperature".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()),
|
||||
);
|
||||
params.insert(
|
||||
"top_p".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()),
|
||||
);
|
||||
params.insert(
|
||||
"frequency_penalty".to_string(),
|
||||
serde_json::Value::Number(
|
||||
serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(),
|
||||
),
|
||||
);
|
||||
params.insert(
|
||||
"presence_penalty".to_string(),
|
||||
serde_json::Value::Number(
|
||||
serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(),
|
||||
),
|
||||
);
|
||||
params.insert(
|
||||
"top_k".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from(self.top_k)),
|
||||
);
|
||||
params.insert(
|
||||
"min_p".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()),
|
||||
);
|
||||
params.insert(
|
||||
"repetition_penalty".to_string(),
|
||||
serde_json::Value::Number(
|
||||
serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(),
|
||||
),
|
||||
);
|
||||
|
||||
if let Some(ref stop) = self.stop {
|
||||
match serde_json::to_value(stop) {
|
||||
Ok(value) => params.insert("stop".to_string(), value),
|
||||
Err(_) => params.insert("stop".to_string(), serde_json::Value::Null),
|
||||
};
|
||||
}
|
||||
|
||||
// Apply any additional default parameters
|
||||
if let Some(default_params) = default_params {
|
||||
for (key, value) in default_params {
|
||||
params.entry(key).or_insert(value);
|
||||
}
|
||||
}
|
||||
|
||||
params
|
||||
}
|
||||
}
|
||||
|
||||
impl GenerationRequest for ResponsesRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
self.model.as_deref()
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
match &self.input {
|
||||
ResponseInput::Text(text) => text.clone(),
|
||||
ResponseInput::Items(items) => items
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
ResponseInputOutputItem::Message { content, .. } => {
|
||||
let texts: Vec<String> = content
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
ResponseContentPart::OutputText { text, .. } => text.clone(),
|
||||
})
|
||||
.collect();
|
||||
if texts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(texts.join(" "))
|
||||
}
|
||||
}
|
||||
ResponseInputOutputItem::Reasoning { content, .. } => {
|
||||
let texts: Vec<String> = content
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
ResponseReasoningContent::ReasoningText { text } => text.clone(),
|
||||
})
|
||||
.collect();
|
||||
if texts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(texts.join(" "))
|
||||
}
|
||||
}
|
||||
ResponseInputOutputItem::FunctionToolCall { arguments, .. } => {
|
||||
Some(arguments.clone())
|
||||
}
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join(" "),
|
||||
}
|
||||
}
|
||||
}
|
||||
280
sgl-router/src/protocols/openai/responses/response.rs
Normal file
280
sgl-router/src/protocols/openai/responses/response.rs
Normal file
@@ -0,0 +1,280 @@
|
||||
// Responses API response types
|
||||
|
||||
use crate::protocols::openai::responses::request::ResponsesRequest;
|
||||
use crate::protocols::openai::responses::types::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn generate_response_id() -> String {
|
||||
format!("resp_{}", uuid::Uuid::new_v4().simple())
|
||||
}
|
||||
|
||||
fn current_timestamp() -> i64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_else(|_| std::time::Duration::from_secs(0))
|
||||
.as_secs() as i64
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponsesResponse {
|
||||
/// Response ID
|
||||
#[serde(default = "generate_response_id")]
|
||||
pub id: String,
|
||||
|
||||
/// Object type
|
||||
#[serde(default = "default_object_type")]
|
||||
pub object: String,
|
||||
|
||||
/// Creation timestamp
|
||||
#[serde(default = "current_timestamp")]
|
||||
pub created_at: i64,
|
||||
|
||||
/// Model name
|
||||
pub model: String,
|
||||
|
||||
/// Output items
|
||||
#[serde(default)]
|
||||
pub output: Vec<ResponseOutputItem>,
|
||||
|
||||
/// Response status
|
||||
pub status: ResponseStatus,
|
||||
|
||||
/// Usage information
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<UsageInfo>,
|
||||
|
||||
/// Whether parallel tool calls are enabled
|
||||
#[serde(default = "default_true")]
|
||||
pub parallel_tool_calls: bool,
|
||||
|
||||
/// Tool choice setting
|
||||
#[serde(default = "default_tool_choice")]
|
||||
pub tool_choice: String,
|
||||
|
||||
/// Available tools
|
||||
#[serde(default)]
|
||||
pub tools: Vec<ResponseTool>,
|
||||
}
|
||||
|
||||
fn default_object_type() -> String {
|
||||
"response".to_string()
|
||||
}
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_tool_choice() -> String {
|
||||
"auto".to_string()
|
||||
}
|
||||
|
||||
impl ResponsesResponse {
|
||||
/// Create a response from a request
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn from_request(
|
||||
request: &ResponsesRequest,
|
||||
_sampling_params: &HashMap<String, serde_json::Value>,
|
||||
model_name: String,
|
||||
created_time: i64,
|
||||
output: Vec<ResponseOutputItem>,
|
||||
status: ResponseStatus,
|
||||
usage: Option<UsageInfo>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: request.request_id.clone(),
|
||||
object: "response".to_string(),
|
||||
created_at: created_time,
|
||||
model: model_name,
|
||||
output,
|
||||
status,
|
||||
usage,
|
||||
parallel_tool_calls: request.parallel_tool_calls,
|
||||
tool_choice: match request.tool_choice {
|
||||
ToolChoice::Auto => "auto".to_string(),
|
||||
ToolChoice::Required => "required".to_string(),
|
||||
ToolChoice::None => "none".to_string(),
|
||||
},
|
||||
tools: request.tools.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new response with default values
|
||||
pub fn new(request_id: String, model: String, status: ResponseStatus) -> Self {
|
||||
Self {
|
||||
id: request_id,
|
||||
object: "response".to_string(),
|
||||
created_at: current_timestamp(),
|
||||
model,
|
||||
output: Vec::new(),
|
||||
status,
|
||||
usage: None,
|
||||
parallel_tool_calls: true,
|
||||
tool_choice: "auto".to_string(),
|
||||
tools: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an output item to the response
|
||||
pub fn add_output(&mut self, item: ResponseOutputItem) {
|
||||
self.output.push(item);
|
||||
}
|
||||
|
||||
/// Set the usage information
|
||||
pub fn set_usage(&mut self, usage: UsageInfo) {
|
||||
self.usage = Some(usage);
|
||||
}
|
||||
|
||||
/// Update the status
|
||||
pub fn set_status(&mut self, status: ResponseStatus) {
|
||||
self.status = status;
|
||||
}
|
||||
|
||||
/// Check if the response is complete
|
||||
pub fn is_complete(&self) -> bool {
|
||||
matches!(self.status, ResponseStatus::Completed)
|
||||
}
|
||||
|
||||
/// Check if the response is in progress
|
||||
pub fn is_in_progress(&self) -> bool {
|
||||
matches!(self.status, ResponseStatus::InProgress)
|
||||
}
|
||||
|
||||
/// Check if the response failed
|
||||
pub fn is_failed(&self) -> bool {
|
||||
matches!(self.status, ResponseStatus::Failed)
|
||||
}
|
||||
|
||||
/// Check if the response was cancelled
|
||||
pub fn is_cancelled(&self) -> bool {
|
||||
matches!(self.status, ResponseStatus::Cancelled)
|
||||
}
|
||||
|
||||
/// Check if the response is queued
|
||||
pub fn is_queued(&self) -> bool {
|
||||
matches!(self.status, ResponseStatus::Queued)
|
||||
}
|
||||
|
||||
/// Convert usage to OpenAI Responses API format
|
||||
pub fn usage_in_response_format(
|
||||
&self,
|
||||
) -> Option<crate::protocols::openai::responses::types::ResponseUsage> {
|
||||
self.usage.as_ref().map(|usage| usage.to_response_usage())
|
||||
}
|
||||
|
||||
/// Get the response as a JSON value with usage in response format
|
||||
pub fn to_response_format(&self) -> serde_json::Value {
|
||||
let mut response = serde_json::to_value(self).unwrap_or(serde_json::Value::Null);
|
||||
|
||||
// Convert usage to response format if present
|
||||
if let Some(usage) = &self.usage {
|
||||
if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) {
|
||||
response["usage"] = usage_value;
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Helper Functions =============
|
||||
|
||||
impl ResponseOutputItem {
|
||||
/// Create a new message output item
|
||||
pub fn new_message(
|
||||
id: String,
|
||||
role: String,
|
||||
content: Vec<ResponseContentPart>,
|
||||
status: String,
|
||||
) -> Self {
|
||||
Self::Message {
|
||||
id,
|
||||
role,
|
||||
content,
|
||||
status,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new reasoning output item
|
||||
pub fn new_reasoning(
|
||||
id: String,
|
||||
summary: Vec<String>,
|
||||
content: Vec<ResponseReasoningContent>,
|
||||
status: Option<String>,
|
||||
) -> Self {
|
||||
Self::Reasoning {
|
||||
id,
|
||||
summary,
|
||||
content,
|
||||
status,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new function tool call output item
|
||||
pub fn new_function_tool_call(
|
||||
id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
output: Option<String>,
|
||||
status: String,
|
||||
) -> Self {
|
||||
Self::FunctionToolCall {
|
||||
id,
|
||||
name,
|
||||
arguments,
|
||||
output,
|
||||
status,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseContentPart {
|
||||
/// Create a new text content part
|
||||
pub fn new_text(
|
||||
text: String,
|
||||
annotations: Vec<String>,
|
||||
logprobs: Option<crate::protocols::openai::common::ChatLogProbs>,
|
||||
) -> Self {
|
||||
Self::OutputText {
|
||||
text,
|
||||
annotations,
|
||||
logprobs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseReasoningContent {
|
||||
/// Create a new reasoning text content
|
||||
pub fn new_reasoning_text(text: String) -> Self {
|
||||
Self::ReasoningText { text }
|
||||
}
|
||||
}
|
||||
|
||||
impl UsageInfo {
|
||||
/// Create a new usage info with token counts
|
||||
pub fn new(prompt_tokens: u32, completion_tokens: u32, reasoning_tokens: Option<u32>) -> Self {
|
||||
Self {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens: prompt_tokens + completion_tokens,
|
||||
reasoning_tokens,
|
||||
prompt_tokens_details: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create usage info with cached token details
|
||||
pub fn new_with_cached(
|
||||
prompt_tokens: u32,
|
||||
completion_tokens: u32,
|
||||
reasoning_tokens: Option<u32>,
|
||||
cached_tokens: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
total_tokens: prompt_tokens + completion_tokens,
|
||||
reasoning_tokens,
|
||||
prompt_tokens_details: Some(PromptTokenUsageInfo { cached_tokens }),
|
||||
}
|
||||
}
|
||||
}
|
||||
296
sgl-router/src/protocols/openai/responses/types.rs
Normal file
296
sgl-router/src/protocols/openai/responses/types.rs
Normal file
@@ -0,0 +1,296 @@
|
||||
// Supporting types for Responses API
|
||||
|
||||
use crate::protocols::openai::common::ChatLogProbs;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ============= Tool Definitions =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponseTool {
|
||||
#[serde(rename = "type")]
|
||||
pub r#type: ResponseToolType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseToolType {
|
||||
WebSearchPreview,
|
||||
CodeInterpreter,
|
||||
}
|
||||
|
||||
// ============= Reasoning Configuration =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponseReasoningParam {
|
||||
#[serde(default = "default_reasoning_effort")]
|
||||
pub effort: Option<ReasoningEffort>,
|
||||
}
|
||||
|
||||
fn default_reasoning_effort() -> Option<ReasoningEffort> {
|
||||
Some(ReasoningEffort::Medium)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ReasoningEffort {
|
||||
Low,
|
||||
Medium,
|
||||
High,
|
||||
}
|
||||
|
||||
// ============= Input/Output Items =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseInputOutputItem {
|
||||
#[serde(rename = "message")]
|
||||
Message {
|
||||
id: String,
|
||||
role: String,
|
||||
content: Vec<ResponseContentPart>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
status: Option<String>,
|
||||
},
|
||||
#[serde(rename = "reasoning")]
|
||||
Reasoning {
|
||||
id: String,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
summary: Vec<String>,
|
||||
content: Vec<ResponseReasoningContent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
status: Option<String>,
|
||||
},
|
||||
#[serde(rename = "function_tool_call")]
|
||||
FunctionToolCall {
|
||||
id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
output: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
status: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseContentPart {
|
||||
#[serde(rename = "output_text")]
|
||||
OutputText {
|
||||
text: String,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
annotations: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
logprobs: Option<ChatLogProbs>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseReasoningContent {
|
||||
#[serde(rename = "reasoning_text")]
|
||||
ReasoningText { text: String },
|
||||
}
|
||||
|
||||
// ============= Output Items for Response =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseOutputItem {
|
||||
#[serde(rename = "message")]
|
||||
Message {
|
||||
id: String,
|
||||
role: String,
|
||||
content: Vec<ResponseContentPart>,
|
||||
status: String,
|
||||
},
|
||||
#[serde(rename = "reasoning")]
|
||||
Reasoning {
|
||||
id: String,
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
summary: Vec<String>,
|
||||
content: Vec<ResponseReasoningContent>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
status: Option<String>,
|
||||
},
|
||||
#[serde(rename = "function_tool_call")]
|
||||
FunctionToolCall {
|
||||
id: String,
|
||||
name: String,
|
||||
arguments: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
output: Option<String>,
|
||||
status: String,
|
||||
},
|
||||
}
|
||||
|
||||
// ============= Service Tier =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ServiceTier {
|
||||
Auto,
|
||||
Default,
|
||||
Flex,
|
||||
Scale,
|
||||
Priority,
|
||||
}
|
||||
|
||||
impl Default for ServiceTier {
|
||||
fn default() -> Self {
|
||||
Self::Auto
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Tool Choice =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolChoice {
|
||||
Auto,
|
||||
Required,
|
||||
None,
|
||||
}
|
||||
|
||||
impl Default for ToolChoice {
|
||||
fn default() -> Self {
|
||||
Self::Auto
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Truncation =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum Truncation {
|
||||
Auto,
|
||||
Disabled,
|
||||
}
|
||||
|
||||
impl Default for Truncation {
|
||||
fn default() -> Self {
|
||||
Self::Disabled
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Response Status =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponseStatus {
|
||||
Queued,
|
||||
InProgress,
|
||||
Completed,
|
||||
Failed,
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
// ============= Include Fields =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum IncludeField {
|
||||
#[serde(rename = "code_interpreter_call.outputs")]
|
||||
CodeInterpreterCallOutputs,
|
||||
#[serde(rename = "computer_call_output.output.image_url")]
|
||||
ComputerCallOutputImageUrl,
|
||||
#[serde(rename = "file_search_call.results")]
|
||||
FileSearchCallResults,
|
||||
#[serde(rename = "message.input_image.image_url")]
|
||||
MessageInputImageUrl,
|
||||
#[serde(rename = "message.output_text.logprobs")]
|
||||
MessageOutputTextLogprobs,
|
||||
#[serde(rename = "reasoning.encrypted_content")]
|
||||
ReasoningEncryptedContent,
|
||||
}
|
||||
|
||||
// ============= Usage Info =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct UsageInfo {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct PromptTokenUsageInfo {
|
||||
pub cached_tokens: u32,
|
||||
}
|
||||
|
||||
// ============= Response Usage Format =============
|
||||
|
||||
/// OpenAI Responses API usage format (different from standard UsageInfo)
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponseUsage {
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub input_tokens_details: Option<InputTokensDetails>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_tokens_details: Option<OutputTokensDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct InputTokensDetails {
|
||||
pub cached_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct OutputTokensDetails {
|
||||
pub reasoning_tokens: u32,
|
||||
}
|
||||
|
||||
impl UsageInfo {
|
||||
/// Convert to OpenAI Responses API format
|
||||
pub fn to_response_usage(&self) -> ResponseUsage {
|
||||
ResponseUsage {
|
||||
input_tokens: self.prompt_tokens,
|
||||
output_tokens: self.completion_tokens,
|
||||
total_tokens: self.total_tokens,
|
||||
input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| {
|
||||
InputTokensDetails {
|
||||
cached_tokens: details.cached_tokens,
|
||||
}
|
||||
}),
|
||||
output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails {
|
||||
reasoning_tokens: tokens,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UsageInfo> for ResponseUsage {
|
||||
fn from(usage: UsageInfo) -> Self {
|
||||
usage.to_response_usage()
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseUsage {
|
||||
/// Convert back to standard UsageInfo format
|
||||
pub fn to_usage_info(&self) -> UsageInfo {
|
||||
UsageInfo {
|
||||
prompt_tokens: self.input_tokens,
|
||||
completion_tokens: self.output_tokens,
|
||||
total_tokens: self.total_tokens,
|
||||
reasoning_tokens: self
|
||||
.output_tokens_details
|
||||
.as_ref()
|
||||
.map(|details| details.reasoning_tokens),
|
||||
prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| {
|
||||
PromptTokenUsageInfo {
|
||||
cached_tokens: details.cached_tokens,
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
208
sgl-router/tests/responses_api_test.rs
Normal file
208
sgl-router/tests/responses_api_test.rs
Normal file
@@ -0,0 +1,208 @@
|
||||
// Integration test for Responses API
|
||||
|
||||
use sglang_router_rs::protocols::common::GenerationRequest;
|
||||
use sglang_router_rs::protocols::openai::responses::request::ResponseInput;
|
||||
use sglang_router_rs::protocols::openai::responses::*;
|
||||
|
||||
#[test]
|
||||
fn test_responses_request_creation() {
|
||||
let request = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("Hello, world!".to_string()),
|
||||
instructions: Some("Be helpful".to_string()),
|
||||
max_output_tokens: Some(100),
|
||||
max_tool_calls: None,
|
||||
metadata: None,
|
||||
model: Some("test-model".to_string()),
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: Some(ResponseReasoningParam {
|
||||
effort: Some(ReasoningEffort::Medium),
|
||||
}),
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: true,
|
||||
stream: false,
|
||||
temperature: Some(0.7),
|
||||
tool_choice: ToolChoice::Auto,
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::WebSearchPreview,
|
||||
}],
|
||||
top_logprobs: 5,
|
||||
top_p: Some(0.9),
|
||||
truncation: Truncation::Disabled,
|
||||
user: Some("test-user".to_string()),
|
||||
request_id: "resp_test123".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
top_k: -1,
|
||||
min_p: 0.0,
|
||||
repetition_penalty: 1.0,
|
||||
};
|
||||
|
||||
// Test GenerationRequest trait implementation
|
||||
assert!(!request.is_stream());
|
||||
assert_eq!(request.get_model(), Some("test-model"));
|
||||
let routing_text = request.extract_text_for_routing();
|
||||
assert_eq!(routing_text, "Hello, world!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sampling_params_conversion() {
|
||||
let request = ResponsesRequest {
|
||||
background: false,
|
||||
include: None,
|
||||
input: ResponseInput::Text("Test".to_string()),
|
||||
instructions: None,
|
||||
max_output_tokens: Some(50),
|
||||
max_tool_calls: None,
|
||||
metadata: None,
|
||||
model: Some("test-model".to_string()),
|
||||
parallel_tool_calls: true, // Use default true
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: ServiceTier::Auto,
|
||||
store: true, // Use default true
|
||||
stream: false,
|
||||
temperature: Some(0.8),
|
||||
tool_choice: ToolChoice::Auto,
|
||||
tools: vec![],
|
||||
top_logprobs: 0, // Use default 0
|
||||
top_p: Some(0.95),
|
||||
truncation: Truncation::Auto,
|
||||
user: None,
|
||||
request_id: "resp_test456".to_string(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.1,
|
||||
presence_penalty: 0.2,
|
||||
stop: None,
|
||||
top_k: 10,
|
||||
min_p: 0.05,
|
||||
repetition_penalty: 1.1,
|
||||
};
|
||||
|
||||
let params = request.to_sampling_params(1000, None);
|
||||
|
||||
// Check that parameters are converted correctly
|
||||
assert!(params.contains_key("temperature"));
|
||||
assert!(params.contains_key("top_p"));
|
||||
assert!(params.contains_key("frequency_penalty"));
|
||||
assert!(params.contains_key("max_new_tokens"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_responses_response_creation() {
|
||||
let response = ResponsesResponse::new(
|
||||
"resp_test789".to_string(),
|
||||
"test-model".to_string(),
|
||||
ResponseStatus::Completed,
|
||||
);
|
||||
|
||||
assert_eq!(response.id, "resp_test789");
|
||||
assert_eq!(response.model, "test-model");
|
||||
assert!(response.is_complete());
|
||||
assert!(!response.is_in_progress());
|
||||
assert!(!response.is_failed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_usage_conversion() {
|
||||
let usage_info = UsageInfo::new_with_cached(15, 25, Some(8), 3);
|
||||
let response_usage = usage_info.to_response_usage();
|
||||
|
||||
assert_eq!(response_usage.input_tokens, 15);
|
||||
assert_eq!(response_usage.output_tokens, 25);
|
||||
assert_eq!(response_usage.total_tokens, 40);
|
||||
|
||||
// Check details are converted correctly
|
||||
assert!(response_usage.input_tokens_details.is_some());
|
||||
assert_eq!(
|
||||
response_usage
|
||||
.input_tokens_details
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.cached_tokens,
|
||||
3
|
||||
);
|
||||
|
||||
assert!(response_usage.output_tokens_details.is_some());
|
||||
assert_eq!(
|
||||
response_usage
|
||||
.output_tokens_details
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.reasoning_tokens,
|
||||
8
|
||||
);
|
||||
|
||||
// Test reverse conversion
|
||||
let back_to_usage = response_usage.to_usage_info();
|
||||
assert_eq!(back_to_usage.prompt_tokens, 15);
|
||||
assert_eq!(back_to_usage.completion_tokens, 25);
|
||||
assert_eq!(back_to_usage.reasoning_tokens, Some(8));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reasoning_param_default() {
|
||||
let param = ResponseReasoningParam {
|
||||
effort: Some(ReasoningEffort::Medium),
|
||||
};
|
||||
|
||||
// Test JSON serialization/deserialization preserves default
|
||||
let json = serde_json::to_string(¶m).unwrap();
|
||||
let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert!(matches!(parsed.effort, Some(ReasoningEffort::Medium)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_serialization() {
|
||||
let request = ResponsesRequest {
|
||||
background: true,
|
||||
include: None,
|
||||
input: ResponseInput::Text("Test input".to_string()),
|
||||
instructions: Some("Test instructions".to_string()),
|
||||
max_output_tokens: Some(200),
|
||||
max_tool_calls: Some(5),
|
||||
metadata: None,
|
||||
model: Some("gpt-4".to_string()),
|
||||
parallel_tool_calls: false,
|
||||
previous_response_id: None,
|
||||
reasoning: Some(ResponseReasoningParam {
|
||||
effort: Some(ReasoningEffort::High),
|
||||
}),
|
||||
service_tier: ServiceTier::Priority,
|
||||
store: false,
|
||||
stream: true,
|
||||
temperature: Some(0.9),
|
||||
tool_choice: ToolChoice::Required,
|
||||
tools: vec![ResponseTool {
|
||||
r#type: ResponseToolType::CodeInterpreter,
|
||||
}],
|
||||
top_logprobs: 10,
|
||||
top_p: Some(0.8),
|
||||
truncation: Truncation::Auto,
|
||||
user: Some("test_user".to_string()),
|
||||
request_id: "resp_comprehensive_test".to_string(),
|
||||
priority: 1,
|
||||
frequency_penalty: 0.3,
|
||||
presence_penalty: 0.4,
|
||||
stop: None,
|
||||
top_k: 50,
|
||||
min_p: 0.1,
|
||||
repetition_penalty: 1.2,
|
||||
};
|
||||
|
||||
// Test that everything can be serialized to JSON and back
|
||||
let json = serde_json::to_string(&request).expect("Serialization should work");
|
||||
let parsed: ResponsesRequest =
|
||||
serde_json::from_str(&json).expect("Deserialization should work");
|
||||
|
||||
assert_eq!(parsed.request_id, "resp_comprehensive_test");
|
||||
assert_eq!(parsed.model, Some("gpt-4".to_string()));
|
||||
assert!(parsed.background);
|
||||
assert!(parsed.stream);
|
||||
assert_eq!(parsed.tools.len(), 1);
|
||||
}
|
||||
Reference in New Issue
Block a user