[router] responses api POST and GET with local storage (#10581)
Co-authored-by: key4ng <rukeyang@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use serde_json::{to_value, Map, Number, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// # Protocol Specifications
|
||||
@@ -350,7 +350,7 @@ pub struct ChatCompletionRequest {
|
||||
|
||||
/// Session parameters for continual prompting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
||||
pub session_params: Option<HashMap<String, Value>>,
|
||||
|
||||
/// Separate reasoning content from final answer (O1-style models)
|
||||
#[serde(default = "default_true")]
|
||||
@@ -362,7 +362,7 @@ pub struct ChatCompletionRequest {
|
||||
|
||||
/// Chat template kwargs
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chat_template_kwargs: Option<HashMap<String, serde_json::Value>>,
|
||||
pub chat_template_kwargs: Option<HashMap<String, Value>>,
|
||||
|
||||
/// Return model hidden states
|
||||
#[serde(default)]
|
||||
@@ -447,7 +447,7 @@ pub struct ChatChoice {
|
||||
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
||||
/// Information about which stop condition was matched
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
||||
pub matched_stop: Option<Value>, // Can be string or integer
|
||||
/// Hidden states from the model (SGLang extension)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub hidden_states: Option<Vec<f32>>,
|
||||
@@ -606,7 +606,7 @@ pub struct CompletionRequest {
|
||||
|
||||
/// Session parameters for continual prompting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
||||
pub session_params: Option<HashMap<String, Value>>,
|
||||
|
||||
/// Return model hidden states
|
||||
#[serde(default)]
|
||||
@@ -618,7 +618,7 @@ pub struct CompletionRequest {
|
||||
|
||||
/// Additional fields including bootstrap info for PD routing
|
||||
#[serde(flatten)]
|
||||
pub other: serde_json::Map<String, serde_json::Value>,
|
||||
pub other: Map<String, Value>,
|
||||
}
|
||||
|
||||
impl GenerationRequest for CompletionRequest {
|
||||
@@ -662,7 +662,7 @@ pub struct CompletionChoice {
|
||||
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
|
||||
/// Information about which stop condition was matched
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
||||
pub matched_stop: Option<Value>, // Can be string or integer
|
||||
/// Hidden states from the model (SGLang extension)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub hidden_states: Option<Vec<f32>>,
|
||||
@@ -776,6 +776,10 @@ pub enum ResponseContentPart {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
logprobs: Option<ChatLogProbs>,
|
||||
},
|
||||
#[serde(rename = "input_text")]
|
||||
InputText { text: String },
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -864,6 +868,29 @@ pub enum ResponseStatus {
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
// ============= Reasoning Info =============
|
||||
|
||||
/// Reasoning information reported on a response (the type of
/// `ResponsesResponse::reasoning`).
///
/// Both fields are optional and are left out of the serialized output
/// when they are `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ReasoningInfo {
|
||||
/// Reasoning effort, carried as a plain string.
///
/// NOTE(review): `ResponsesResponse::from_request` fills this with
/// `format!("{:?}", e)`, i.e. the `Debug` rendering of the request's
/// effort value — confirm that spelling is what clients expect.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub effort: Option<String>,
|
||||
/// Reasoning summary, if any. `ResponsesResponse::from_request` always
/// leaves this as `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub summary: Option<String>,
|
||||
}
|
||||
|
||||
// ============= Text Format =============
|
||||
|
||||
/// Text format settings reported on a response (the type of
/// `ResponsesResponse::text`).
///
/// A thin wrapper whose only content is the nested format descriptor.
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponseTextFormat {
|
||||
/// The format descriptor; required (no serde default is declared).
pub format: TextFormatType,
|
||||
}
|
||||
|
||||
/// Format descriptor nested inside `ResponseTextFormat`.
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TextFormatType {
|
||||
/// Name of the format. Read from / written to the JSON key `type`
/// (renamed because `type` is a Rust keyword).
/// `ResponsesResponse::from_request` sets it to `"text"`.
#[serde(rename = "type")]
|
||||
pub format_type: String,
|
||||
}
|
||||
|
||||
// ============= Include Fields =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -915,6 +942,13 @@ pub struct ResponseUsage {
|
||||
pub output_tokens_details: Option<OutputTokensDetails>,
|
||||
}
|
||||
|
||||
/// Usage information for a response, held in one of two shapes.
///
/// The enum is `#[serde(untagged)]`: no variant tag is written, and on
/// deserialization serde tries the variants in declaration order, so
/// `Classic` is attempted before `Modern`.
///
/// Use `to_response_usage` / `to_usage_info` to obtain a specific shape
/// regardless of which variant is stored.
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ResponsesUsage {
|
||||
/// Usage expressed as a `UsageInfo`.
Classic(UsageInfo),
|
||||
/// Usage expressed as a `ResponseUsage`.
Modern(ResponseUsage),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct InputTokensDetails {
|
||||
pub cached_tokens: u32,
|
||||
@@ -970,6 +1004,34 @@ impl ResponseUsage {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters for retrieving a response.
///
/// Every field carries `#[serde(default)]`, so an input that omits all of
/// them deserializes successfully; the struct also derives `Default`.
///
/// NOTE(review): presumably these are the query parameters of the
/// responses GET endpoint — confirm against the handler that uses them.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
|
||||
pub struct ResponsesGetParams {
|
||||
/// List of `include` values; empty when absent.
#[serde(default)]
|
||||
pub include: Vec<String>,
|
||||
/// Optional `include_obfuscation` flag; `None` when absent.
#[serde(default)]
|
||||
pub include_obfuscation: Option<bool>,
|
||||
/// Optional `starting_after` value; `None` when absent.
/// NOTE(review): unit/meaning (e.g. sequence number) is not visible here.
#[serde(default)]
|
||||
pub starting_after: Option<i64>,
|
||||
/// Optional `stream` flag; `None` when absent.
#[serde(default)]
|
||||
pub stream: Option<bool>,
|
||||
}
|
||||
|
||||
/// Conversions that yield one specific usage shape no matter which
/// variant is stored.
impl ResponsesUsage {
|
||||
/// Returns the usage as a `ResponseUsage`.
///
/// A `Classic` value is converted through `UsageInfo::to_response_usage`;
/// a `Modern` value is already in that shape and is cloned.
pub fn to_response_usage(&self) -> ResponseUsage {
|
||||
match self {
|
||||
ResponsesUsage::Classic(usage) => usage.to_response_usage(),
|
||||
ResponsesUsage::Modern(usage) => usage.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the usage as a `UsageInfo`.
///
/// A `Classic` value is already in that shape and is cloned; a `Modern`
/// value is converted through `ResponseUsage::to_usage_info`.
pub fn to_usage_info(&self) -> UsageInfo {
|
||||
match self {
|
||||
ResponsesUsage::Classic(usage) => usage.clone(),
|
||||
ResponsesUsage::Modern(usage) => usage.to_usage_info(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a new request/response identifier: the prefix `resp_` followed
/// by a freshly generated version-4 UUID in its `simple()` rendering
/// (per the `uuid` crate, the hyphen-less hex form).
///
/// Each call generates a new UUID, so successive calls return different
/// strings.
fn generate_request_id() -> String {
|
||||
format!("resp_{}", uuid::Uuid::new_v4().simple())
|
||||
}
|
||||
@@ -1002,7 +1064,7 @@ pub struct ResponsesRequest {
|
||||
|
||||
/// Additional metadata
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, serde_json::Value>>,
|
||||
pub metadata: Option<HashMap<String, Value>>,
|
||||
|
||||
/// Model to use (optional to match vLLM)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -1109,6 +1171,42 @@ fn default_repetition_penalty() -> f32 {
|
||||
1.0
|
||||
}
|
||||
|
||||
/// Baseline `ResponsesRequest`: empty text input, no optional settings,
/// and a freshly generated `request_id`.
impl Default for ResponsesRequest {
|
||||
/// Returns a request with every optional field unset and the remaining
/// fields at the fixed values listed below.
///
/// Because `request_id` is generated on every call, two values produced
/// by `default()` do not share an ID.
fn default() -> Self {
|
||||
Self {
|
||||
background: false,
|
||||
include: None,
|
||||
// Empty text input; callers are expected to overwrite this.
input: ResponseInput::Text(String::new()),
|
||||
instructions: None,
|
||||
max_output_tokens: None,
|
||||
max_tool_calls: None,
|
||||
metadata: None,
|
||||
model: None,
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: ServiceTier::default(),
|
||||
store: true,
|
||||
stream: false,
|
||||
temperature: None,
|
||||
tool_choice: ToolChoice::default(),
|
||||
tools: Vec::new(),
|
||||
top_logprobs: 0,
|
||||
top_p: None,
|
||||
truncation: Truncation::default(),
|
||||
user: None,
|
||||
// New `resp_<uuid>` identifier generated on each call.
request_id: generate_request_id(),
|
||||
priority: 0,
|
||||
frequency_penalty: 0.0,
|
||||
presence_penalty: 0.0,
|
||||
stop: None,
|
||||
// Shared default helpers, so these two stay in sync with whatever
// else in the file calls the same functions.
top_k: default_top_k(),
|
||||
min_p: 0.0,
|
||||
repetition_penalty: default_repetition_penalty(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponsesRequest {
|
||||
/// Default sampling parameters
|
||||
const DEFAULT_TEMPERATURE: f32 = 0.7;
|
||||
@@ -1118,8 +1216,8 @@ impl ResponsesRequest {
|
||||
pub fn to_sampling_params(
|
||||
&self,
|
||||
default_max_tokens: u32,
|
||||
default_params: Option<HashMap<String, serde_json::Value>>,
|
||||
) -> HashMap<String, serde_json::Value> {
|
||||
default_params: Option<HashMap<String, Value>>,
|
||||
) -> HashMap<String, Value> {
|
||||
let mut params = HashMap::new();
|
||||
|
||||
// Use max_output_tokens if available
|
||||
@@ -1154,47 +1252,38 @@ impl ResponsesRequest {
|
||||
|
||||
params.insert(
|
||||
"max_new_tokens".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from(max_tokens)),
|
||||
Value::Number(Number::from(max_tokens)),
|
||||
);
|
||||
params.insert(
|
||||
"temperature".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()),
|
||||
Value::Number(Number::from_f64(temperature as f64).unwrap()),
|
||||
);
|
||||
params.insert(
|
||||
"top_p".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()),
|
||||
Value::Number(Number::from_f64(top_p as f64).unwrap()),
|
||||
);
|
||||
params.insert(
|
||||
"frequency_penalty".to_string(),
|
||||
serde_json::Value::Number(
|
||||
serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(),
|
||||
),
|
||||
Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()),
|
||||
);
|
||||
params.insert(
|
||||
"presence_penalty".to_string(),
|
||||
serde_json::Value::Number(
|
||||
serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(),
|
||||
),
|
||||
);
|
||||
params.insert(
|
||||
"top_k".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from(self.top_k)),
|
||||
Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()),
|
||||
);
|
||||
params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
|
||||
params.insert(
|
||||
"min_p".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()),
|
||||
Value::Number(Number::from_f64(self.min_p as f64).unwrap()),
|
||||
);
|
||||
params.insert(
|
||||
"repetition_penalty".to_string(),
|
||||
serde_json::Value::Number(
|
||||
serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(),
|
||||
),
|
||||
Value::Number(Number::from_f64(self.repetition_penalty as f64).unwrap()),
|
||||
);
|
||||
|
||||
if let Some(ref stop) = self.stop {
|
||||
match serde_json::to_value(stop) {
|
||||
match to_value(stop) {
|
||||
Ok(value) => params.insert("stop".to_string(), value),
|
||||
Err(_) => params.insert("stop".to_string(), serde_json::Value::Null),
|
||||
Err(_) => params.insert("stop".to_string(), Value::Null),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1227,8 +1316,10 @@ impl GenerationRequest for ResponsesRequest {
|
||||
ResponseInputOutputItem::Message { content, .. } => {
|
||||
let texts: Vec<String> = content
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
ResponseContentPart::OutputText { text, .. } => text.clone(),
|
||||
.filter_map(|part| match part {
|
||||
ResponseContentPart::OutputText { text, .. } => Some(text.clone()),
|
||||
ResponseContentPart::InputText { text } => Some(text.clone()),
|
||||
ResponseContentPart::Unknown => None,
|
||||
})
|
||||
.collect();
|
||||
if texts.is_empty() {
|
||||
@@ -1285,6 +1376,25 @@ pub struct ResponsesResponse {
|
||||
#[serde(default = "current_timestamp")]
|
||||
pub created_at: i64,
|
||||
|
||||
/// Response status
|
||||
pub status: ResponseStatus,
|
||||
|
||||
/// Error information if status is failed
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<Value>,
|
||||
|
||||
/// Incomplete details if response was truncated
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub incomplete_details: Option<Value>,
|
||||
|
||||
/// System instructions used
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub instructions: Option<String>,
|
||||
|
||||
/// Max output tokens setting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_output_tokens: Option<u32>,
|
||||
|
||||
/// Model name
|
||||
pub model: String,
|
||||
|
||||
@@ -1292,17 +1402,30 @@ pub struct ResponsesResponse {
|
||||
#[serde(default)]
|
||||
pub output: Vec<ResponseOutputItem>,
|
||||
|
||||
/// Response status
|
||||
pub status: ResponseStatus,
|
||||
|
||||
/// Usage information
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<UsageInfo>,
|
||||
|
||||
/// Whether parallel tool calls are enabled
|
||||
#[serde(default = "default_true")]
|
||||
pub parallel_tool_calls: bool,
|
||||
|
||||
/// Previous response ID if this is a continuation
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub previous_response_id: Option<String>,
|
||||
|
||||
/// Reasoning information
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning: Option<ReasoningInfo>,
|
||||
|
||||
/// Whether the response is stored
|
||||
#[serde(default = "default_true")]
|
||||
pub store: bool,
|
||||
|
||||
/// Temperature setting used
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// Text format settings
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<ResponseTextFormat>,
|
||||
|
||||
/// Tool choice setting
|
||||
#[serde(default = "default_tool_choice")]
|
||||
pub tool_choice: String,
|
||||
@@ -1310,6 +1433,26 @@ pub struct ResponsesResponse {
|
||||
/// Available tools
|
||||
#[serde(default)]
|
||||
pub tools: Vec<ResponseTool>,
|
||||
|
||||
/// Top-p setting used
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// Truncation strategy used
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub truncation: Option<String>,
|
||||
|
||||
/// Usage information
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<ResponsesUsage>,
|
||||
|
||||
/// User identifier
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
/// Additional metadata
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
fn default_object_type() -> String {
|
||||
@@ -1325,7 +1468,7 @@ impl ResponsesResponse {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn from_request(
|
||||
request: &ResponsesRequest,
|
||||
_sampling_params: &HashMap<String, serde_json::Value>,
|
||||
_sampling_params: &HashMap<String, Value>,
|
||||
model_name: String,
|
||||
created_time: i64,
|
||||
output: Vec<ResponseOutputItem>,
|
||||
@@ -1336,11 +1479,26 @@ impl ResponsesResponse {
|
||||
id: request.request_id.clone(),
|
||||
object: "response".to_string(),
|
||||
created_at: created_time,
|
||||
status,
|
||||
error: None,
|
||||
incomplete_details: None,
|
||||
instructions: request.instructions.clone(),
|
||||
max_output_tokens: request.max_output_tokens,
|
||||
model: model_name,
|
||||
output,
|
||||
status,
|
||||
usage,
|
||||
parallel_tool_calls: request.parallel_tool_calls,
|
||||
previous_response_id: request.previous_response_id.clone(),
|
||||
reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo {
|
||||
effort: r.effort.as_ref().map(|e| format!("{:?}", e)),
|
||||
summary: None,
|
||||
}),
|
||||
store: request.store,
|
||||
temperature: request.temperature,
|
||||
text: Some(ResponseTextFormat {
|
||||
format: TextFormatType {
|
||||
format_type: "text".to_string(),
|
||||
},
|
||||
}),
|
||||
tool_choice: match &request.tool_choice {
|
||||
ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
|
||||
ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
|
||||
@@ -1348,6 +1506,14 @@ impl ResponsesResponse {
|
||||
ToolChoice::Function { .. } => "function".to_string(),
|
||||
},
|
||||
tools: request.tools.clone(),
|
||||
top_p: request.top_p,
|
||||
truncation: match &request.truncation {
|
||||
Truncation::Auto => Some("auto".to_string()),
|
||||
Truncation::Disabled => Some("disabled".to_string()),
|
||||
},
|
||||
usage: usage.map(ResponsesUsage::Classic),
|
||||
user: request.user.clone(),
|
||||
metadata: request.metadata.clone().unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1357,13 +1523,26 @@ impl ResponsesResponse {
|
||||
id: request_id,
|
||||
object: "response".to_string(),
|
||||
created_at: current_timestamp(),
|
||||
status,
|
||||
error: None,
|
||||
incomplete_details: None,
|
||||
instructions: None,
|
||||
max_output_tokens: None,
|
||||
model,
|
||||
output: Vec::new(),
|
||||
status,
|
||||
usage: None,
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
store: true,
|
||||
temperature: None,
|
||||
text: None,
|
||||
tool_choice: "auto".to_string(),
|
||||
tools: Vec::new(),
|
||||
top_p: None,
|
||||
truncation: None,
|
||||
usage: None,
|
||||
user: None,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1374,7 +1553,7 @@ impl ResponsesResponse {
|
||||
|
||||
/// Set the usage information
|
||||
pub fn set_usage(&mut self, usage: UsageInfo) {
|
||||
self.usage = Some(usage);
|
||||
self.usage = Some(ResponsesUsage::Classic(usage));
|
||||
}
|
||||
|
||||
/// Update the status
|
||||
@@ -1413,12 +1592,12 @@ impl ResponsesResponse {
|
||||
}
|
||||
|
||||
/// Get the response as a JSON value with usage in response format
|
||||
pub fn to_response_format(&self) -> serde_json::Value {
|
||||
let mut response = serde_json::to_value(self).unwrap_or(serde_json::Value::Null);
|
||||
pub fn to_response_format(&self) -> Value {
|
||||
let mut response = to_value(self).unwrap_or(Value::Null);
|
||||
|
||||
// Convert usage to response format if present
|
||||
if let Some(usage) = &self.usage {
|
||||
if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) {
|
||||
if let Ok(usage_value) = to_value(usage.to_response_usage()) {
|
||||
response["usage"] = usage_value;
|
||||
}
|
||||
}
|
||||
@@ -1641,8 +1820,13 @@ pub struct LogProbs {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatLogProbs {
|
||||
pub content: Option<Vec<ChatLogProbsContent>>,
|
||||
/// Log-probability payload, accepted either in a structured form or as
/// arbitrary JSON.
///
/// The enum is `#[serde(untagged)]`: no variant tag is written, and on
/// deserialization serde tries `Detailed` first and falls back to `Raw`.
#[serde(untagged)]
|
||||
pub enum ChatLogProbs {
|
||||
/// Structured form with an optional `content` list.
///
/// NOTE(review): because `content` is optional, a JSON object without
/// a `content` key presumably also deserializes as `Detailed` (with
/// `content: None`) rather than `Raw` — confirm that is intended.
Detailed {
|
||||
/// Entries of the structured form; omitted from serialized output
/// when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<Vec<ChatLogProbsContent>>,
|
||||
},
|
||||
/// Any other JSON value, kept as-is.
Raw(Value),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -1798,7 +1982,7 @@ pub struct GenerateRequest {
|
||||
|
||||
/// Session parameters for continual prompting
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
||||
pub session_params: Option<HashMap<String, Value>>,
|
||||
|
||||
/// Return model hidden states
|
||||
#[serde(default)]
|
||||
@@ -2065,7 +2249,7 @@ pub struct EmbeddingRequest {
|
||||
pub model: String,
|
||||
|
||||
/// Input can be a string, array of strings, tokens, or batch inputs
|
||||
pub input: serde_json::Value,
|
||||
pub input: Value,
|
||||
|
||||
/// Optional encoding format (e.g., "float", "base64")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -2097,8 +2281,8 @@ impl GenerationRequest for EmbeddingRequest {
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
// Best effort: extract text content for routing decisions
|
||||
match &self.input {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Array(arr) => arr
|
||||
Value::String(s) => s.clone(),
|
||||
Value::Array(arr) => arr
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
@@ -2173,7 +2357,7 @@ pub enum LoRAPath {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json;
|
||||
use serde_json::{from_str, json, to_string};
|
||||
|
||||
// ==================================================================
|
||||
// = RERANK REQUEST TESTS =
|
||||
@@ -2191,8 +2375,8 @@ mod tests {
|
||||
user: Some("user-456".to_string()),
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string(&request).unwrap();
|
||||
let deserialized: RerankRequest = serde_json::from_str(&serialized).unwrap();
|
||||
let serialized = to_string(&request).unwrap();
|
||||
let deserialized: RerankRequest = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.query, request.query);
|
||||
assert_eq!(deserialized.documents, request.documents);
|
||||
@@ -2210,7 +2394,7 @@ mod tests {
|
||||
"documents": ["doc1", "doc2"]
|
||||
}"#;
|
||||
|
||||
let request: RerankRequest = serde_json::from_str(json).unwrap();
|
||||
let request: RerankRequest = from_str(json).unwrap();
|
||||
|
||||
assert_eq!(request.query, "test query");
|
||||
assert_eq!(request.documents, vec!["doc1", "doc2"]);
|
||||
@@ -2402,8 +2586,8 @@ mod tests {
|
||||
Some(StringOrArray::String("req-123".to_string())),
|
||||
);
|
||||
|
||||
let serialized = serde_json::to_string(&response).unwrap();
|
||||
let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap();
|
||||
let serialized = to_string(&response).unwrap();
|
||||
let deserialized: RerankResponse = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.results.len(), response.results.len());
|
||||
assert_eq!(deserialized.model, response.model);
|
||||
@@ -2539,13 +2723,13 @@ mod tests {
|
||||
("confidence".to_string(), Value::String("high".to_string())),
|
||||
(
|
||||
"processing_time".to_string(),
|
||||
Value::Number(serde_json::Number::from(150)),
|
||||
Value::Number(Number::from(150)),
|
||||
),
|
||||
])),
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string(&result).unwrap();
|
||||
let deserialized: RerankResult = serde_json::from_str(&serialized).unwrap();
|
||||
let serialized = to_string(&result).unwrap();
|
||||
let deserialized: RerankResult = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.score, result.score);
|
||||
assert_eq!(deserialized.document, result.document);
|
||||
@@ -2562,8 +2746,8 @@ mod tests {
|
||||
meta_info: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string(&result).unwrap();
|
||||
let deserialized: RerankResult = serde_json::from_str(&serialized).unwrap();
|
||||
let serialized = to_string(&result).unwrap();
|
||||
let deserialized: RerankResult = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.score, result.score);
|
||||
assert_eq!(deserialized.document, result.document);
|
||||
@@ -2582,8 +2766,8 @@ mod tests {
|
||||
documents: vec!["doc1".to_string(), "doc2".to_string()],
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string(&v1_input).unwrap();
|
||||
let deserialized: V1RerankReqInput = serde_json::from_str(&serialized).unwrap();
|
||||
let serialized = to_string(&v1_input).unwrap();
|
||||
let deserialized: V1RerankReqInput = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.query, v1_input.query);
|
||||
assert_eq!(deserialized.documents, v1_input.documents);
|
||||
@@ -2724,8 +2908,8 @@ mod tests {
|
||||
prompt_tokens_details: None,
|
||||
});
|
||||
|
||||
let serialized = serde_json::to_string(&response).unwrap();
|
||||
let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap();
|
||||
let serialized = to_string(&response).unwrap();
|
||||
let deserialized: RerankResponse = from_str(&serialized).unwrap();
|
||||
|
||||
assert!(deserialized.usage.is_some());
|
||||
let usage = deserialized.usage.unwrap();
|
||||
@@ -2805,8 +2989,8 @@ mod tests {
|
||||
assert_eq!(response.model, "rerank-model");
|
||||
|
||||
// Serialize and deserialize
|
||||
let serialized = serde_json::to_string(&response).unwrap();
|
||||
let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap();
|
||||
let serialized = to_string(&response).unwrap();
|
||||
let deserialized: RerankResponse = from_str(&serialized).unwrap();
|
||||
assert_eq!(deserialized.results.len(), 2);
|
||||
assert_eq!(deserialized.model, response.model);
|
||||
}
|
||||
@@ -2819,15 +3003,15 @@ mod tests {
|
||||
fn test_embedding_request_serialization_string_input() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "test-emb".to_string(),
|
||||
input: serde_json::Value::String("hello".to_string()),
|
||||
input: Value::String("hello".to_string()),
|
||||
encoding_format: Some("float".to_string()),
|
||||
user: Some("user-1".to_string()),
|
||||
dimensions: Some(128),
|
||||
rid: Some("rid-123".to_string()),
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string(&req).unwrap();
|
||||
let deserialized: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
|
||||
let serialized = to_string(&req).unwrap();
|
||||
let deserialized: EmbeddingRequest = from_str(&serialized).unwrap();
|
||||
|
||||
assert_eq!(deserialized.model, req.model);
|
||||
assert_eq!(deserialized.input, req.input);
|
||||
@@ -2841,15 +3025,15 @@ mod tests {
|
||||
fn test_embedding_request_serialization_array_input() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "test-emb".to_string(),
|
||||
input: serde_json::json!(["a", "b", "c"]),
|
||||
input: json!(["a", "b", "c"]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
rid: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string(&req).unwrap();
|
||||
let de: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
|
||||
let serialized = to_string(&req).unwrap();
|
||||
let de: EmbeddingRequest = from_str(&serialized).unwrap();
|
||||
assert_eq!(de.model, req.model);
|
||||
assert_eq!(de.input, req.input);
|
||||
}
|
||||
@@ -2858,7 +3042,7 @@ mod tests {
|
||||
fn test_embedding_generation_request_trait_string() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: serde_json::Value::String("hello".to_string()),
|
||||
input: Value::String("hello".to_string()),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
@@ -2873,7 +3057,7 @@ mod tests {
|
||||
fn test_embedding_generation_request_trait_array() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: serde_json::json!(["hello", "world"]),
|
||||
input: json!(["hello", "world"]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
@@ -2886,7 +3070,7 @@ mod tests {
|
||||
fn test_embedding_generation_request_trait_non_text() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: serde_json::json!({"tokens": [1, 2, 3]}),
|
||||
input: json!({"tokens": [1, 2, 3]}),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
@@ -2899,7 +3083,7 @@ mod tests {
|
||||
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
|
||||
let req = EmbeddingRequest {
|
||||
model: "emb-model".to_string(),
|
||||
input: serde_json::json!(["a", ["b", "c"], 123, {"k": "v"}]),
|
||||
input: json!(["a", ["b", "c"], 123, {"k": "v"}]),
|
||||
encoding_format: None,
|
||||
user: None,
|
||||
dimensions: None,
|
||||
|
||||
Reference in New Issue
Block a user